---
**Documentation**: https://starlette.dev
**Source Code**: https://github.com/Kludex/starlette
---
# Starlette
Starlette is a lightweight [ASGI][asgi] framework/toolkit,
which is ideal for building async web services in Python.
It is production-ready, and gives you the following:
* A lightweight, low-complexity HTTP web framework.
* WebSocket support.
* In-process background tasks.
* Startup and shutdown events.
* Test client built on `httpx`.
* CORS, GZip, Static Files, Streaming responses.
* Session and Cookie support.
* 100% test coverage.
* 100% type annotated codebase.
* Few hard dependencies.
* Compatible with `asyncio` and `trio` backends.
* Great overall performance [against independent benchmarks][techempower].
## Installation
```shell
$ pip install starlette
```
You'll also want to install an ASGI server, such as [uvicorn](https://www.uvicorn.org/), [daphne](https://github.com/django/daphne/), or [hypercorn](https://hypercorn.readthedocs.io/en/latest/).
```shell
$ pip install uvicorn
```
## Example
```python title="main.py"
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route


async def homepage(request):
    return JSONResponse({'hello': 'world'})


routes = [
    Route("/", endpoint=homepage)
]

app = Starlette(debug=True, routes=routes)
```
Then run the application using Uvicorn:
```shell
$ uvicorn main:app
```
## Dependencies
Starlette only requires `anyio`, and the following are optional:
* [`httpx`][httpx] - Required if you want to use the `TestClient`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
* [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support.
You can install all of these with `pip install starlette[full]`.
## Framework or Toolkit
Starlette is designed to be used either as a complete framework, or as
an ASGI toolkit. You can use any of its components independently.
```python
from starlette.responses import PlainTextResponse


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    response = PlainTextResponse('Hello, world!')
    await response(scope, receive, send)
```
Run the `app` application in `example.py`:
```shell
$ uvicorn example:app
INFO: Started server process [11509]
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
```
Run uvicorn with `--reload` to enable auto-reloading on code changes.
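To make the `(scope, receive, send)` interface more concrete, the sketch below drives an ASGI app by hand, without a server. It uses only the standard library. The hand-written `app`, the `call_app` helper, and the minimal `scope` dictionary are illustrative assumptions, not part of Starlette. The two messages it sends are the kind of ASGI messages a response object such as `PlainTextResponse` emits on your behalf.

```python
import asyncio


async def app(scope, receive, send):
    # A hand-written ASGI app: it sends the response start and body itself.
    assert scope["type"] == "http"
    body = b"Hello, world!"
    await send({
        "type": "http.response.start",
        "status": 200,
        "headers": [
            (b"content-type", b"text/plain; charset=utf-8"),
            (b"content-length", str(len(body)).encode()),
        ],
    })
    await send({"type": "http.response.body", "body": body})


async def call_app(asgi_app, path="/"):
    # Hypothetical helper: builds a minimal HTTP scope and records sent messages.
    scope = {
        "type": "http",
        "method": "GET",
        "path": path,
        "headers": [],
        "query_string": b"",
    }
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    await asgi_app(scope, receive, send)
    return sent


messages = asyncio.run(call_app(app))
```

Because every component speaks this same interface, any ASGI callable can be passed to `call_app` in place of the hand-written `app`.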
## Modularity
The modularity that Starlette is designed on promotes building re-usable
components that can be shared between any ASGI framework. This should enable
an ecosystem of shared middleware and mountable applications.
The clean API separation also means it's easier to understand each component
in isolation.
---
Starlette is BSD licensed code. Designed & crafted with care. ⭐️
[asgi]: https://asgi.readthedocs.io/en/latest/
[httpx]: https://www.python-httpx.org/
[jinja2]: https://jinja.palletsprojects.com/
[python-multipart]: https://multipart.fastapiexpert.com/
[itsdangerous]: https://itsdangerous.palletsprojects.com/
[sqlalchemy]: https://www.sqlalchemy.org
[pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation
[techempower]: https://www.techempower.com/benchmarks/#hw=ph&test=fortune&l=zijzen-sf
================================================
FILE: docs/CNAME
================================================
www.starlette.io
================================================
FILE: docs/applications.md
================================================
??? abstract "API Reference"
    ::: starlette.applications.Starlette
        options:
            parameter_headings: false
            show_root_heading: true
            heading_level: 3
            filters:
                - "__init__"
Starlette includes an application class `Starlette` that nicely ties together all of
its other functionality.
```python
from contextlib import asynccontextmanager

from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.routing import Route, Mount, WebSocketRoute
from starlette.staticfiles import StaticFiles


def homepage(request):
    return PlainTextResponse('Hello, world!')

def user_me(request):
    username = "John Doe"
    return PlainTextResponse('Hello, %s!' % username)

def user(request):
    username = request.path_params['username']
    return PlainTextResponse('Hello, %s!' % username)

async def websocket_endpoint(websocket):
    await websocket.accept()
    await websocket.send_text('Hello, websocket!')
    await websocket.close()

@asynccontextmanager
async def lifespan(app):
    print('Startup')
    yield
    print('Shutdown')


routes = [
    Route('/', homepage),
    Route('/user/me', user_me),
    Route('/user/{username}', user),
    WebSocketRoute('/ws', websocket_endpoint),
    Mount('/static', StaticFiles(directory="static")),
]

app = Starlette(debug=True, routes=routes, lifespan=lifespan)
```
### Storing state on the app instance
You can store arbitrary extra state on the application instance, using the
generic `app.state` attribute.
For example:
```python
app.state.ADMIN_EMAIL = 'admin@example.org'
```
### Accessing the app instance
Where a `request` is available (i.e. endpoints and middleware), the app is available on `request.app`.
================================================
FILE: docs/authentication.md
================================================
Starlette offers a simple but powerful interface for handling authentication
and permissions. Once you've installed `AuthenticationMiddleware` with an
appropriate authentication backend, the `request.user` and `request.auth`
interfaces will be available in your endpoints.
```python
from starlette.applications import Starlette
from starlette.authentication import (
    AuthCredentials, AuthenticationBackend, AuthenticationError, SimpleUser
)
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Route
import base64
import binascii


class BasicAuthBackend(AuthenticationBackend):
    async def authenticate(self, conn):
        if "Authorization" not in conn.headers:
            return

        auth = conn.headers["Authorization"]
        try:
            scheme, credentials = auth.split()
            if scheme.lower() != 'basic':
                return
            decoded = base64.b64decode(credentials).decode("ascii")
        except (ValueError, UnicodeDecodeError, binascii.Error) as exc:
            raise AuthenticationError('Invalid basic auth credentials')

        username, _, password = decoded.partition(":")
        # TODO: You'd want to verify the username and password here.
        return AuthCredentials(["authenticated"]), SimpleUser(username)


async def homepage(request):
    if request.user.is_authenticated:
        return PlainTextResponse('Hello, ' + request.user.display_name)
    return PlainTextResponse('Hello, you')


routes = [
    Route("/", endpoint=homepage)
]

middleware = [
    Middleware(AuthenticationMiddleware, backend=BasicAuthBackend())
]
app = Starlette(routes=routes, middleware=middleware)
```
## Users
Once `AuthenticationMiddleware` is installed the `request.user` interface
will be available to endpoints or other middleware.
This interface should subclass `BaseUser`, which provides two properties,
as well as whatever other information your user model includes.
* `.is_authenticated`
* `.display_name`
Starlette provides two built-in user implementations: `UnauthenticatedUser()`,
and `SimpleUser(username)`.
## AuthCredentials
It is important that authentication credentials are treated as a separate concept
from users. An authentication scheme should be able to restrict or grant
particular privileges independently of the user identity.
The `AuthCredentials` class provides the basic interface that `request.auth`
exposes:
* `.scopes`
## Permissions
Permissions are implemented as an endpoint decorator, that enforces that the
incoming request includes the required authentication scopes.
```python
from starlette.authentication import requires


@requires('authenticated')
async def dashboard(request):
    ...
```
You can include either one or multiple required scopes:
```python
from starlette.authentication import requires


@requires(['authenticated', 'admin'])
async def dashboard(request):
    ...
```
By default 403 responses will be returned when permissions are not granted.
In some cases you might want to customize this, for example to hide information
about the URL layout from unauthenticated users.
```python
from starlette.authentication import requires


@requires(['authenticated', 'admin'], status_code=404)
async def dashboard(request):
    ...
```
!!! note
    The `status_code` parameter is not supported with WebSockets. The 403 (Forbidden)
    status code will always be used for those.
Alternatively you might want to redirect unauthenticated users to a different
page.
```python
from starlette.authentication import requires


async def homepage(request):
    ...


@requires('authenticated', redirect='homepage')
async def dashboard(request):
    ...
```
When redirecting users, the page you redirect them to will include the URL they originally requested in the `next` query parameter:
```python
from starlette.authentication import requires
from starlette.responses import RedirectResponse


@requires('authenticated', redirect='login')
async def admin(request):
    ...


async def login(request):
    if request.method == "POST":
        # Now that the user is authenticated,
        # we can send them to their original request destination
        if request.user.is_authenticated:
            next_url = request.query_params.get("next")
            if next_url:
                return RedirectResponse(next_url)
            return RedirectResponse("/")
```
For class-based endpoints, you should wrap the decorator
around a method on the class.
```python
from starlette.authentication import requires
from starlette.endpoints import HTTPEndpoint


class Dashboard(HTTPEndpoint):
    @requires("authenticated")
    async def get(self, request):
        ...
```
## Custom authentication error responses
You can customise the error response sent when an `AuthenticationError` is
raised by an auth backend:
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse


def on_auth_error(request: Request, exc: Exception):
    return JSONResponse({"error": str(exc)}, status_code=401)


app = Starlette(
    middleware=[
        Middleware(AuthenticationMiddleware, backend=BasicAuthBackend(), on_error=on_auth_error),
    ],
)
```
================================================
FILE: docs/background.md
================================================
Starlette includes a `BackgroundTask` class for in-process background tasks.
A background task should be attached to a response, and will run only once
the response has been sent.
### Background Task
Used to add a single background task to a response.
Signature: `BackgroundTask(func, *args, **kwargs)`
```python
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.background import BackgroundTask


...

async def signup(request):
    data = await request.json()
    username = data['username']
    email = data['email']
    task = BackgroundTask(send_welcome_email, to_address=email)
    message = {'status': 'Signup successful'}
    return JSONResponse(message, background=task)

async def send_welcome_email(to_address):
    ...


routes = [
    ...
    Route('/user/signup', endpoint=signup, methods=['POST'])
]

app = Starlette(routes=routes)
```
### BackgroundTasks
Used to add multiple background tasks to a response.
Signature: `BackgroundTasks(tasks=[])`
```python
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.background import BackgroundTasks


async def signup(request):
    data = await request.json()
    username = data['username']
    email = data['email']
    tasks = BackgroundTasks()
    tasks.add_task(send_welcome_email, to_address=email)
    tasks.add_task(send_admin_notification, username=username)
    message = {'status': 'Signup successful'}
    return JSONResponse(message, background=tasks)

async def send_welcome_email(to_address):
    ...

async def send_admin_notification(username):
    ...


routes = [
    Route('/user/signup', endpoint=signup, methods=['POST'])
]

app = Starlette(routes=routes)
```
!!! important
    The tasks are executed in order. In case one of the tasks raises
    an exception, the following tasks will not get the opportunity to be executed.
================================================
FILE: docs/config.md
================================================
Starlette encourages a strict separation of configuration from code,
following [the twelve-factor pattern][twelve-factor].
Configuration should be stored in environment variables, or in a `.env` file
that is not committed to source control.
```python title="main.py"
from sqlalchemy import create_engine
from starlette.applications import Starlette
from starlette.config import Config
from starlette.datastructures import CommaSeparatedStrings, Secret
# Config will be read from environment variables and/or ".env" files.
config = Config(".env")
DEBUG = config('DEBUG', cast=bool, default=False)
DATABASE_URL = config('DATABASE_URL')
SECRET_KEY = config('SECRET_KEY', cast=Secret)
ALLOWED_HOSTS = config('ALLOWED_HOSTS', cast=CommaSeparatedStrings)
app = Starlette(debug=DEBUG)
engine = create_engine(DATABASE_URL)
...
```
```shell title=".env"
# Don't commit this to source control.
# Eg. Include ".env" in your `.gitignore` file.
DEBUG=True
DATABASE_URL=postgresql://user:password@localhost:5432/database
SECRET_KEY=43n080musdfjt54t-09sdgr
ALLOWED_HOSTS=127.0.0.1, localhost
```
## Configuration precedence
The order in which configuration values are read is:
* From an environment variable.
* From the `.env` file.
* The default value given in `config`.
If none of those match, then `config(...)` will raise an error.
## Secrets
For sensitive keys, the `Secret` class is useful, since it helps minimize
occasions where the value it holds could leak out into tracebacks or
other code introspection.
To get the value of a `Secret` instance, you must explicitly cast it to a string.
You should only do this at the point at which the value is used.
```python
>>> from myproject import settings
>>> settings.SECRET_KEY
Secret('**********')
>>> str(settings.SECRET_KEY)
'98n349$%8b8-7yjn0n8y93T$23r'
```
!!! tip
    You can use `DatabaseURL` from `databases`
    package [here](https://github.com/encode/databases/blob/ab5eb718a78a27afe18775754e9c0fa2ad9cd211/databases/core.py#L420)
    to store database URLs and avoid leaking them in the logs.
## CommaSeparatedStrings
For holding multiple values inside a single config key, the `CommaSeparatedStrings`
type is useful.
```python
>>> from myproject import settings
>>> settings.ALLOWED_HOSTS
CommaSeparatedStrings(['127.0.0.1', 'localhost'])
>>> print(list(settings.ALLOWED_HOSTS))
['127.0.0.1', 'localhost']
>>> print(len(settings.ALLOWED_HOSTS))
2
>>> settings.ALLOWED_HOSTS[0]
'127.0.0.1'
```
## Reading or modifying the environment
In some cases you might want to read or modify the environment variables programmatically.
This is particularly useful in testing, where you may want to override particular
keys in the environment.
Rather than reading or writing from `os.environ`, you should use Starlette's
`environ` instance. This instance is a mapping onto the standard `os.environ`
that additionally protects you by raising an error if any environment variable
is set *after* the point that it has already been read by the configuration.
If you're using `pytest`, then you can setup any initial environment in
`tests/conftest.py`.
```python title="tests/conftest.py"
from starlette.config import environ
environ['DEBUG'] = 'TRUE'
```
## Reading prefixed environment variables
You can namespace the environment variables by setting the `env_prefix` argument.
```python title="myproject/settings.py"
import os
from starlette.config import Config
os.environ['APP_DEBUG'] = 'yes'
os.environ['ENVIRONMENT'] = 'dev'
config = Config(env_prefix='APP_')
DEBUG = config('DEBUG')  # looks up APP_DEBUG, returns "yes"
ENVIRONMENT = config('ENVIRONMENT')  # looks up APP_ENVIRONMENT, raises KeyError as the variable is not defined
```
## Custom encoding for environment files
By default, Starlette reads environment files using UTF-8 encoding.
You can specify a different encoding by setting the `encoding` argument.
```python title="myproject/settings.py"
from starlette.config import Config
# Using custom encoding for .env file
config = Config(".env", encoding="latin-1")
```
## A full example
Structuring large applications can be complex. You need proper separation of
configuration and code, database isolation during tests, separate test and
production databases, etc...
Here we'll take a look at a complete example, that demonstrates how
we can start to structure an application.
First, let's keep our settings, our database table definitions, and our
application logic separated:
```python title="myproject/settings.py"
from starlette.config import Config
from starlette.datastructures import Secret
config = Config(".env")
DEBUG = config('DEBUG', cast=bool, default=False)
SECRET_KEY = config('SECRET_KEY', cast=Secret)
DATABASE_URL = config('DATABASE_URL')
```
```python title="myproject/tables.py"
import sqlalchemy
# Database table definitions.
metadata = sqlalchemy.MetaData()
organisations = sqlalchemy.Table(
    ...
)
```
```python title="myproject/app.py"
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.routing import Route
from myproject import settings


async def homepage(request):
    ...

routes = [
    Route("/", endpoint=homepage)
]

middleware = [
    Middleware(
        SessionMiddleware,
        secret_key=settings.SECRET_KEY,
    )
]

app = Starlette(debug=settings.DEBUG, routes=routes, middleware=middleware)
```
Now let's deal with our test configuration.
We'd like to create a new test database every time the test suite runs,
and drop it once the tests complete.
```python title="tests/conftest.py"
import pytest
from starlette.config import environ
from starlette.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy_utils import create_database, database_exists, drop_database

# This line would raise an error if we use it after 'settings' has been imported.
environ['DEBUG'] = 'TRUE'

from myproject import settings
from myproject.app import app
from myproject.tables import metadata


@pytest.fixture(autouse=True, scope="session")
def setup_test_database():
    """
    Create a clean test database every time the tests are run.
    """
    url = settings.DATABASE_URL
    engine = create_engine(url)
    assert not database_exists(url), 'Test database already exists. Aborting tests.'
    create_database(url)             # Create the test database.
    metadata.create_all(engine)      # Create the tables.
    yield                            # Run the tests.
    drop_database(url)               # Drop the test database.


@pytest.fixture()
def client():
    """
    Make a 'client' fixture available to test cases.
    """
    # Our fixture is created within a context manager. This ensures that
    # application lifespan runs for every test case.
    with TestClient(app) as test_client:
        yield test_client
```
[twelve-factor]: https://12factor.net/config
================================================
FILE: docs/contributing.md
================================================
# Contributing
Thank you for being interested in contributing to Starlette.
There are many ways you can contribute to the project:
- Try Starlette and [report bugs/issues you find](https://github.com/Kludex/starlette/issues/new)
- [Implement new features](https://github.com/Kludex/starlette/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22)
- [Review Pull Requests of others](https://github.com/Kludex/starlette/pulls)
- Write documentation
- Participate in discussions
## Reporting Bugs or Other Issues
Found something that Starlette should support?
Stumbled upon some unexpected behaviour?
Contributions should generally start out with [a discussion](https://github.com/Kludex/starlette/discussions).
Possible bugs may be raised as a "Potential Issue" discussion, feature requests may
be raised as an "Ideas" discussion. We can then determine if the discussion needs
to be escalated into an "Issue" or not, or if we'd consider a pull request.
Try to be as descriptive as you can and, in the case of a bug report,
provide as much information as possible, such as:
- OS platform
- Python version
- Installed dependencies and versions (`python -m pip freeze`)
- Code snippet
- Error traceback
You should always try to reduce any examples to the *simplest possible case*
that demonstrates the issue.
## Development
To start developing Starlette, create a **fork** of the
[Starlette repository](https://github.com/Kludex/starlette) on GitHub.
Then clone your fork with the following command replacing `YOUR-USERNAME` with
your GitHub username:
```shell
$ git clone https://github.com/YOUR-USERNAME/starlette
```
You can now install the project and its dependencies using:
```shell
$ cd starlette
$ scripts/install
```
## Testing and Linting
We use custom shell scripts to automate testing, linting,
and documentation building workflow.
To run the tests, use:
```shell
$ scripts/test
```
Any additional arguments will be passed to `pytest`. See the [pytest documentation](https://docs.pytest.org/en/latest/how-to/usage.html) for more information.
For example, to run a single test script:
```shell
$ scripts/test tests/test_application.py
```
To run the code auto-formatting:
```shell
$ scripts/lint
```
Lastly, to run code checks separately (they are also run as part of `scripts/test`), run:
```shell
$ scripts/check
```
## Documenting
Documentation pages are located under the `docs/` folder.
To run the documentation site locally (useful for previewing changes), use:
```shell
$ scripts/docs
```
## Resolving Build / CI Failures
Once you've submitted your pull request, the test suite will automatically run, and the results will show up in GitHub.
If the test suite fails, you'll want to click through to the "Details" link, and try to identify why the test suite failed.
Here are some common ways the test suite can fail:
### Check Job Failed
This job failing means there is either a code formatting issue or type-annotation issue.
You can look at the job output to figure out why it failed, or run the following in a shell:
```shell
$ scripts/check
```
It may be worth running `scripts/lint` to auto-format the code and,
if that succeeds, committing the changes.
### Docs Job Failed
This job failing means the documentation failed to build. This can happen for
a variety of reasons like invalid markdown or missing configuration within `mkdocs.yml`.
### Python 3.X Job Failed
This job failing means the unit tests failed or not all code paths are covered by unit tests.
If tests are failing you will see this message under the coverage report:
`=== 1 failed, 435 passed, 1 skipped, 1 xfailed in 11.09s ===`
If tests succeed but coverage doesn't reach our current threshold, you will see this
message under the coverage report:
`FAIL Required test coverage of 100% not reached. Total coverage: 99.00%`
## Releasing
*This section is targeted at Starlette maintainers.*
Before releasing a new version, create a pull request that includes:
- **An update to the changelog**:
    - We follow the format from [keepachangelog](https://keepachangelog.com/en/1.0.0/).
    - [Compare](https://github.com/Kludex/starlette/compare/) `main` with the tag of the latest release, and list all entries that are of interest to our users:
        - Things that **must** go in the changelog: added, changed, deprecated or removed features, and bug fixes.
        - Things that **should not** go in the changelog: changes to documentation, tests or tooling.
        - Try sorting entries in descending order of impact / importance.
        - Keep it concise and to-the-point. 🎯
- **A version bump**: see `__version__.py`.
For an example, see [#1600](https://github.com/Kludex/starlette/pull/1600).
Once the release PR is merged, create a
[new release](https://github.com/Kludex/starlette/releases/new) including:
- Tag version like `0.13.3`.
- Release title `Version 0.13.3`
- Description copied from the changelog.
Once created this release will be automatically uploaded to PyPI.
================================================
FILE: docs/css/custom.css
================================================
/* Lighter dark mode colors */
[data-md-color-scheme="slate"] {
    --md-default-bg-color: #263238;
    --md-default-fg-color: #e0e0e0;
    --md-code-bg-color: #2e3c43;
}

/* Announcement bar styling */
.announce-wrapper {
    display: flex;
    justify-content: center;
    align-items: center;
    height: 40px;
    min-height: 40px;
    background-color: var(--md-primary-fg-color);
}

.announce-wrapper #announce-msg {
    display: flex;
    align-items: center;
    justify-content: center;
}

.announce-wrapper #announce-msg div.item {
    display: none;
}

.announce-wrapper #announce-msg div.item:first-child {
    display: block;
}

a.announce-link:link,
a.announce-link:visited {
    color: var(--md-primary-bg-color);
    text-decoration: none;
    font-weight: 500;
}

a.announce-link:hover {
    color: var(--md-accent-fg-color);
    text-decoration: underline;
}
================================================
FILE: docs/database.md
================================================
Starlette is not strictly tied to any particular database implementation.
You are free to use any async database library that you prefer. Some popular options include:
- [SQLAlchemy](https://www.sqlalchemy.org/) - The Python SQL toolkit with native async support (2.0+).
- [SQLModel](https://sqlmodel.tiangolo.com/) - SQL databases in Python, designed for simplicity, built on top of SQLAlchemy and Pydantic.
- [Tortoise ORM](https://tortoise.github.io/) - An easy-to-use asyncio ORM inspired by Django.
- [Piccolo](https://piccolo-orm.com/) - A fast, user-friendly ORM and query builder.
Refer to your chosen database library's documentation for specific connection and query patterns.
================================================
FILE: docs/endpoints.md
================================================
Starlette includes the classes `HTTPEndpoint` and `WebSocketEndpoint` that provide a class-based view pattern for
handling HTTP method dispatching and WebSocket sessions.
### HTTPEndpoint
The `HTTPEndpoint` class can be used as an ASGI application:
```python
from starlette.responses import PlainTextResponse
from starlette.endpoints import HTTPEndpoint


class App(HTTPEndpoint):
    async def get(self, request):
        return PlainTextResponse("Hello, world!")
```
If you're using a Starlette application instance to handle routing, you can
dispatch to an `HTTPEndpoint` class. Make sure to dispatch to the class itself,
rather than to an instance of the class:
```python
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.endpoints import HTTPEndpoint
from starlette.routing import Route


class Homepage(HTTPEndpoint):
    async def get(self, request):
        return PlainTextResponse("Hello, world!")


class User(HTTPEndpoint):
    async def get(self, request):
        username = request.path_params['username']
        return PlainTextResponse(f"Hello, {username}")


routes = [
    Route("/", Homepage),
    Route("/{username}", User)
]

app = Starlette(routes=routes)
```
HTTP endpoint classes will respond with "405 Method not allowed" responses for any
request methods which do not map to a corresponding handler.
### WebSocketEndpoint
The `WebSocketEndpoint` class is an ASGI application that presents a wrapper around
the functionality of a `WebSocket` instance.
The ASGI connection scope is accessible on the endpoint instance via `.scope` and
has an attribute `encoding` which may optionally be set, in order to validate the expected websocket data in the `on_receive` method.
The encoding types are:
* `'json'`
* `'bytes'`
* `'text'`
There are three overridable methods for handling specific ASGI websocket message types:
* `async def on_connect(websocket, **kwargs)`
* `async def on_receive(websocket, data)`
* `async def on_disconnect(websocket, close_code)`
```python
from starlette.endpoints import WebSocketEndpoint


class App(WebSocketEndpoint):
    encoding = 'bytes'

    async def on_connect(self, websocket):
        await websocket.accept()

    async def on_receive(self, websocket, data):
        await websocket.send_bytes(b"Message: " + data)

    async def on_disconnect(self, websocket, close_code):
        pass
```
The `WebSocketEndpoint` can also be used with the `Starlette` application class:
```python
import uvicorn
from starlette.applications import Starlette
from starlette.endpoints import WebSocketEndpoint, HTTPEndpoint
from starlette.responses import HTMLResponse
from starlette.routing import Route, WebSocketRoute
html = """
Chat
WebSocket Chat
"""

class Homepage(HTTPEndpoint):
    async def get(self, request):
        return HTMLResponse(html)

class Echo(WebSocketEndpoint):
    encoding = "text"

    async def on_receive(self, websocket, data):
        await websocket.send_text(f"Message text was: {data}")

routes = [
    Route("/", Homepage),
    WebSocketRoute("/ws", Echo)
]
app = Starlette(routes=routes)
```
================================================
FILE: docs/exceptions.md
================================================
Starlette allows you to install custom exception handlers to deal with
how you return responses when errors or handled exceptions occur.
```python
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import HTMLResponse

HTML_404_PAGE = ...
HTML_500_PAGE = ...


async def not_found(request: Request, exc: HTTPException):
    return HTMLResponse(content=HTML_404_PAGE, status_code=exc.status_code)

async def server_error(request: Request, exc: Exception):
    # The 500 handler receives any unhandled exception, which may not
    # carry a `status_code` attribute, so the status is set explicitly.
    return HTMLResponse(content=HTML_500_PAGE, status_code=500)


exception_handlers = {
    404: not_found,
    500: server_error
}

app = Starlette(routes=routes, exception_handlers=exception_handlers)
```
If `debug` is enabled and an error occurs, then instead of using the installed
500 handler, Starlette will respond with a traceback response.
```python
app = Starlette(debug=True, routes=routes, exception_handlers=exception_handlers)
```
As well as registering handlers for specific status codes, you can also
register handlers for classes of exceptions.
In particular you might want to override how the built-in `HTTPException` class
is handled. For example, to use JSON style responses:
```python
async def http_exception(request: Request, exc: HTTPException):
    return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)

exception_handlers = {
    HTTPException: http_exception
}
```
`HTTPException` also accepts a `headers` argument, which allows headers to be
propagated to the response:
```python
async def http_exception(request: Request, exc: HTTPException):
    return JSONResponse(
        {"detail": exc.detail},
        status_code=exc.status_code,
        headers=exc.headers
    )
```
You might also want to override how `WebSocketException` is handled:
```python
async def websocket_exception(websocket: WebSocket, exc: WebSocketException):
    await websocket.close(code=1008)

exception_handlers = {
    WebSocketException: websocket_exception
}
```
## Errors and handled exceptions
It is important to differentiate between handled exceptions and errors.
Handled exceptions do not represent error cases. They are coerced into appropriate
HTTP responses, which are then sent through the standard middleware stack. By default
the `HTTPException` class is used to manage any handled exceptions.
Errors are any other exception that occurs within the application. These cases
should bubble through the entire middleware stack as exceptions. Any error
logging middleware should ensure that it re-raises the exception all the
way up to the server.
In practical terms, the error handler used is `exception_handlers[500]` or `exception_handlers[Exception]`.
Either key, `500` or `Exception`, can be used. See below:
```python
async def handle_error(request: Request, exc: Exception):
    # Perform some logic
    # `exc` can be any exception here, so it has no `detail` or `status_code`.
    return JSONResponse({"detail": "Internal Server Error"}, status_code=500)

exception_handlers = {
    Exception: handle_error  # or "500: handle_error"
}
```
Note that if a [`BackgroundTask`](background.md) raises an exception,
it will be handled by the `handle_error` function, but by that point the response has already been sent, so
the response created by `handle_error` is discarded. If the error happens before the response is sent, the
response returned by the handler is used: in the above example, the returned `JSONResponse`.
In order to deal with this behaviour correctly, the middleware stack of a
`Starlette` application is configured like this:
* `ServerErrorMiddleware` - Returns 500 responses when server errors occur.
* Installed middleware
* `ExceptionMiddleware` - Deals with handled exceptions, and returns responses.
* Router
* Endpoints
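The layering above can be illustrated with a small, framework-free sketch. It is a toy model and not Starlette's implementation: `HandledError`, `exception_layer`, `server_error_layer` and `endpoint` are hypothetical names, and plain dictionaries stand in for responses. The inner layer turns handled exceptions into responses, while the outermost layer answers with a 500 and re-raises anything else.

```python
import asyncio


class HandledError(Exception):
    """Hypothetical stand-in for a handled exception such as HTTPException."""

    def __init__(self, status_code, detail):
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


async def endpoint(path, send):
    if path == "/missing":
        raise HandledError(404, "Not Found")  # a handled exception
    if path == "/crash":
        raise RuntimeError("boom")  # an error
    await send({"status": 200, "body": "ok"})


def exception_layer(app):
    """Inner layer: converts handled exceptions into ordinary responses."""

    async def wrapped(path, send):
        try:
            await app(path, send)
        except HandledError as exc:
            await send({"status": exc.status_code, "body": exc.detail})

    return wrapped


def server_error_layer(app):
    """Outermost layer: sends a 500 response, then re-raises the error."""

    async def wrapped(path, send):
        try:
            await app(path, send)
        except Exception:
            await send({"status": 500, "body": "Internal Server Error"})
            raise

    return wrapped


stack = server_error_layer(exception_layer(endpoint))


async def call(path):
    """Run one request and return (responses sent, error that escaped)."""
    sent = []

    async def send(response):
        sent.append(response)

    try:
        await stack(path, send)
        error = None
    except Exception as exc:
        error = exc
    return sent, error


ok_sent, ok_error = asyncio.run(call("/"))
missing_sent, missing_error = asyncio.run(call("/missing"))
crash_sent, crash_error = asyncio.run(call("/crash"))
```

In this sketch only the `/crash` request lets an exception escape the stack, which is what allows a server or logging layer further out to see it.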
## HTTPException
The `HTTPException` class provides a base class that you can use for any handled exceptions.
The `ExceptionMiddleware` implementation defaults to returning plain-text HTTP responses for any `HTTPException`.
* `HTTPException(status_code, detail=None, headers=None)`
You should only raise `HTTPException` inside routing or endpoints.
Middleware classes should instead just return appropriate responses directly.
You can use an `HTTPException` on a WebSocket endpoint. In case it's raised before `websocket.accept()`
the connection is not upgraded to a WebSocket connection, and the proper HTTP response is returned.
```python
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.routing import WebSocketRoute
from starlette.websockets import WebSocket
async def websocket_endpoint(websocket: WebSocket):
    raise HTTPException(status_code=400, detail="Bad request")

app = Starlette(routes=[WebSocketRoute("/ws", websocket_endpoint)])
```
## WebSocketException
You can use the `WebSocketException` class to raise errors inside of WebSocket endpoints.
* `WebSocketException(code=1008, reason=None)`
You can set any code valid as defined [in the specification](https://tools.ietf.org/html/rfc6455#section-7.4.1).
================================================
FILE: docs/graphql.md
================================================
GraphQL support in Starlette was deprecated in version 0.15.0, and removed in version 0.17.0.
Although GraphQL support is no longer built in to Starlette, you can still use GraphQL with Starlette via 3rd party libraries. These libraries all have Starlette-specific guides to help you do just that:
- [Ariadne](https://ariadnegraphql.org/docs/starlette-integration.html)
- [`starlette-graphene3`](https://github.com/ciscorn/starlette-graphene3#example)
- [Strawberry](https://strawberry.rocks/docs/integrations/starlette)
- [`tartiflette-asgi`](https://tartiflette.github.io/tartiflette-asgi/usage/#starlette)
================================================
FILE: docs/index.md
================================================
✨ The little ASGI framework that shines. ✨
---
**Documentation**: https://starlette.dev
**Source Code**: https://github.com/Kludex/starlette
---
# Introduction
Starlette is a lightweight [ASGI][asgi] framework/toolkit,
which is ideal for building async web services in Python.
It is production-ready, and gives you the following:
* A lightweight, low-complexity HTTP web framework.
* WebSocket support.
* In-process background tasks.
* Startup and shutdown events.
* Test client built on `httpx`.
* CORS, GZip, Static Files, Streaming responses.
* Session and Cookie support.
* 100% test coverage.
* 100% type annotated codebase.
* Few hard dependencies.
* Compatible with `asyncio` and `trio` backends.
* Great overall performance [against independent benchmarks][techempower].
## Sponsorship
Help us keep Starlette maintained and sustainable by [becoming a sponsor](https://github.com/sponsors/Kludex).
## Installation
```shell
pip install starlette
```
You'll also want to install an ASGI server, such as [uvicorn](https://www.uvicorn.org/), [daphne](https://github.com/django/daphne/), or [hypercorn](https://hypercorn.readthedocs.io/en/latest/).
```shell
pip install uvicorn
```
## Example
```python title="main.py"
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
async def homepage(request):
    return JSONResponse({'hello': 'world'})

app = Starlette(debug=True, routes=[
    Route('/', homepage),
])
```
Then run the application...
```shell
uvicorn main:app
```
## Dependencies
Starlette only requires `anyio`, and the following dependencies are optional:
* [`httpx`][httpx] - Required if you want to use the `TestClient`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
* [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support.
You can install all of these with `pip install starlette[full]`.
## Framework or Toolkit
Starlette is designed to be used either as a complete framework, or as
an ASGI toolkit. You can use any of its components independently.
```python title="main.py"
from starlette.responses import PlainTextResponse
async def app(scope, receive, send):
    assert scope['type'] == 'http'
    response = PlainTextResponse('Hello, world!')
    await response(scope, receive, send)
```
Run the `app` application in `main.py`:
```shell
$ uvicorn main:app
INFO: Started server process [11509]
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
```
Run uvicorn with `--reload` to enable auto-reloading on code changes.
## Modularity
The modularity that Starlette is designed on promotes building re-usable
components that can be shared between any ASGI framework. This should enable
an ecosystem of shared middleware and mountable applications.
The clean API separation also means it's easier to understand each component
in isolation.
---
Starlette is BSD licensed code. Designed & crafted with care.

— ⭐️ —
[asgi]: https://asgi.readthedocs.io/en/latest/
[httpx]: https://www.python-httpx.org/
[jinja2]: https://jinja.palletsprojects.com/
[python-multipart]: https://multipart.fastapiexpert.com/
[itsdangerous]: https://itsdangerous.palletsprojects.com/
[sqlalchemy]: https://www.sqlalchemy.org
[pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation
[techempower]: https://www.techempower.com/benchmarks/#hw=ph&test=fortune&l=zijzen-sf
================================================
FILE: docs/js/custom.js
================================================
function shuffle(array) {
var currentIndex = array.length, temporaryValue, randomIndex;
while (0 !== currentIndex) {
randomIndex = Math.floor(Math.random() * currentIndex);
currentIndex -= 1;
temporaryValue = array[currentIndex];
array[currentIndex] = array[randomIndex];
array[randomIndex] = temporaryValue;
}
return array;
}
async function showRandomAnnouncement(groupId, timeInterval) {
const announceGroup = document.getElementById(groupId);
if (announceGroup) {
let children = [].slice.call(announceGroup.children);
children = shuffle(children)
let index = 0
const announceRandom = () => {
children.forEach((el, i) => { el.style.display = "none" });
children[index].style.display = "block"
index = (index + 1) % children.length
}
announceRandom()
setInterval(announceRandom, timeInterval)
}
}
async function main() {
showRandomAnnouncement('announce-msg', 5000)
}
document$.subscribe(() => {
main()
})
================================================
FILE: docs/lifespan.md
================================================
Starlette applications can register a lifespan handler for dealing with
code that needs to run before the application starts up, or when the application
is shutting down.
```python
import contextlib
from starlette.applications import Starlette
@contextlib.asynccontextmanager
async def lifespan(app):
    async with some_async_resource():
        print("Run at startup!")
        yield
        print("Run on shutdown!")
routes = [
...
]
app = Starlette(routes=routes, lifespan=lifespan)
```
Starlette will not start serving any incoming requests until the lifespan has been run.
The lifespan teardown will run once all connections have been closed, and
any in-process background tasks have completed.
Consider using [`anyio.create_task_group()`](https://anyio.readthedocs.io/en/stable/tasks.html)
for managing asynchronous tasks.
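To make the ordering of startup, serving and teardown concrete, here is a minimal, framework-free sketch that plays the server's part of the ASGI lifespan protocol by hand. The `app`, `serve` and `events` names are hypothetical, and the app below handles only the lifespan scope; it is not how Starlette or any particular server is implemented.

```python
import asyncio
import contextlib

events = []


@contextlib.asynccontextmanager
async def lifespan(app):
    events.append("startup")
    yield
    events.append("shutdown")


async def app(scope, receive, send):
    """Minimal ASGI app that handles only the lifespan scope."""
    assert scope["type"] == "lifespan"
    message = await receive()
    assert message["type"] == "lifespan.startup"
    async with lifespan(app):
        await send({"type": "lifespan.startup.complete"})
        # Block here until the server asks the application to shut down.
        message = await receive()
        assert message["type"] == "lifespan.shutdown"
    await send({"type": "lifespan.shutdown.complete"})


async def serve():
    """Play the server's part: start up, 'serve', then shut down."""
    to_app = asyncio.Queue()
    from_app = asyncio.Queue()
    task = asyncio.create_task(app({"type": "lifespan"}, to_app.get, from_app.put))

    await to_app.put({"type": "lifespan.startup"})
    reply = await from_app.get()
    assert reply["type"] == "lifespan.startup.complete"

    # Only now would a real server begin accepting requests.
    events.append("serving requests")

    await to_app.put({"type": "lifespan.shutdown"})
    reply = await from_app.get()
    assert reply["type"] == "lifespan.shutdown.complete"
    await task


asyncio.run(serve())
print(events)
```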
## Lifespan State
The lifespan has the concept of `state`, which is a dictionary that
can be used to share objects between the lifespan and the requests.
```python
import contextlib
from typing import AsyncIterator, TypedDict
import httpx
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
class State(TypedDict):
    http_client: httpx.AsyncClient


@contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[State]:
    async with httpx.AsyncClient() as client:
        yield {"http_client": client}


async def homepage(request: Request) -> PlainTextResponse:
    client = request.state.http_client
    response = await client.get("https://www.example.com")
    return PlainTextResponse(response.text)


app = Starlette(
    lifespan=lifespan,
    routes=[Route("/", homepage)]
)
```
The `state` received on the requests is a **shallow** copy of the state received on the
lifespan handler.
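The practical consequence of a shallow copy can be shown with plain dictionaries. This is only a sketch of the copy semantics: `FakeClient`, `lifespan_state` and `request_state` are hypothetical names, and `copy.copy` stands in for whatever copying the framework performs.

```python
import copy


class FakeClient:
    """Hypothetical stand-in for a shared resource such as an HTTP client."""


lifespan_state = {"http_client": FakeClient(), "counters": {"hits": 0}}

# Assumption: each request sees the equivalent of a shallow copy of the state.
request_state = copy.copy(lifespan_state)

# Adding or rebinding a key only affects this request's copy...
request_state["request_id"] = "abc123"

# ...but mutating a shared object is visible through the original as well.
request_state["counters"]["hits"] += 1
```

After this runs, `lifespan_state` has no `request_id` key, yet its `counters` dictionary shows one hit, because both dictionaries reference the same nested object.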
## Accessing State
The state can be accessed using either attribute-style or dictionary-style syntax.
The dictionary-style syntax was introduced in Starlette 0.52.0 (January 2026), with the idea of
improving type safety when using the lifespan state, given that `Request` became a generic over
the state type.
```python
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import TypedDict
import httpx
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
class State(TypedDict):
    http_client: httpx.AsyncClient


@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[State]:
    async with httpx.AsyncClient() as client:
        yield {"http_client": client}


async def homepage(request: Request[State]) -> PlainTextResponse:
    client = request.state["http_client"]

    reveal_type(client)  # Revealed type is 'httpx.AsyncClient'

    response = await client.get("https://www.example.com")
    return PlainTextResponse(response.text)

app = Starlette(lifespan=lifespan, routes=[Route("/", homepage)])
```
This also works with WebSockets:
```python
from starlette.routing import WebSocketRoute
from starlette.websockets import WebSocket

# `State` and `lifespan` are the ones defined in the previous example.
async def websocket_endpoint(websocket: WebSocket[State]) -> None:
    await websocket.accept()
    client = websocket.state["http_client"]
    response = await client.get("https://www.example.com")
    await websocket.send_text(response.text)
    await websocket.close()

app = Starlette(lifespan=lifespan, routes=[WebSocketRoute("/ws", websocket_endpoint)])
```
!!! note
    There were many attempts to make this work with attribute-style access instead of
    dictionary-style access, but none were satisfactory, given they would have been
    breaking changes, or there were typing limitations.

    For more details, see:

    - [@Kludex/starlette#issues/3005](https://github.com/Kludex/starlette/issues/3005)
    - [@python/typing#discussions/1457](https://github.com/python/typing/discussions/1457)
    - [@Kludex/starlette#pull/3036](https://github.com/Kludex/starlette/pull/3036)
## Running lifespan in tests
You should use `TestClient` as a context manager, to ensure that the lifespan is called.
```python
from example import app
from starlette.testclient import TestClient
def test_homepage():
    with TestClient(app) as client:
        # Application's lifespan is called on entering the block.
        response = client.get("/")
        assert response.status_code == 200

    # And the lifespan's teardown is run when exiting the block.
```
================================================
FILE: docs/middleware.md
================================================
Starlette includes several middleware classes for adding behavior that is applied across
your entire application. These are all implemented as standard ASGI
middleware classes, and can be applied either to Starlette or to any other ASGI application.
## Using middleware
The Starlette application class allows you to include the ASGI middleware
in a way that ensures that it remains wrapped by the exception handler.
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
routes = ...
# Ensure that all requests include an 'example.com' or
# '*.example.com' host header, and strictly enforce https-only access.
middleware = [
Middleware(
TrustedHostMiddleware,
allowed_hosts=['example.com', '*.example.com'],
),
Middleware(HTTPSRedirectMiddleware)
]
app = Starlette(routes=routes, middleware=middleware)
```
Every Starlette application automatically includes two pieces of middleware by default:
* `ServerErrorMiddleware` - Ensures that application exceptions may return a custom 500 page, or display an application traceback in DEBUG mode. This is *always* the outermost middleware layer.
* `ExceptionMiddleware` - Adds exception handlers, so that particular types of expected exception cases can be associated with handler functions. For example raising `HTTPException(status_code=404)` within an endpoint will end up rendering a custom 404 page.
Middleware is evaluated from top-to-bottom, so the flow of execution in our example
application would look like this:
* Middleware
    * `ServerErrorMiddleware`
    * `TrustedHostMiddleware`
    * `HTTPSRedirectMiddleware`
    * `ExceptionMiddleware`
* Routing
* Endpoint
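The top-to-bottom ordering can be sketched without any framework code. In the example below, `make_middleware` and `build_stack` are hypothetical helpers: the list is wrapped in reverse so that the first entry ends up as the outermost layer. Treat it as an illustration of the ordering rather than a copy of Starlette's internals.

```python
import asyncio


def make_middleware(name, trace):
    """Build a hypothetical ASGI middleware factory that records when it runs."""

    def factory(app):
        async def middleware(scope, receive, send):
            trace.append(f"enter {name}")
            await app(scope, receive, send)
            trace.append(f"exit {name}")

        return middleware

    return factory


def build_stack(app, factories):
    # Wrap in reverse so the first listed middleware becomes the outermost one.
    for factory in reversed(factories):
        app = factory(app)
    return app


trace = []


async def endpoint(scope, receive, send):
    trace.append("endpoint")


stack = build_stack(
    endpoint,
    [make_middleware("TrustedHost", trace), make_middleware("HTTPSRedirect", trace)],
)

# `receive` and `send` are unused in this sketch, so None is passed for both.
asyncio.run(stack({"type": "http"}, None, None))
print(trace)
```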
The following middleware implementations are available in the Starlette package:
## CORSMiddleware
Adds appropriate [CORS headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) to outgoing responses in order to allow cross-origin requests from browsers.
The defaults used by the CORSMiddleware implementation are restrictive,
so you'll need to explicitly enable particular origins, methods, or headers in order
for browsers to be permitted to use them in a Cross-Domain context.
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
routes = ...
middleware = [
Middleware(CORSMiddleware, allow_origins=['*'])
]
app = Starlette(routes=routes, middleware=middleware)
```
The following arguments are supported:
* `allow_origins` - A list of origins that should be permitted to make cross-origin requests. eg. `['https://example.org', 'https://www.example.org']`. You can use `['*']` to allow any origin.
* `allow_origin_regex` - A regex string to match against origins that should be permitted to make cross-origin requests. eg. `'https://[a-zA-Z0-9-]+\.example\.org'`. Avoid using `.*` or `.+` as they match URL special characters (`/`, `@`, `#`, `?`) and may result in overly permissive origin matching. Use specific character classes like `[a-zA-Z0-9-]+` instead.
* `allow_methods` - A list of HTTP methods that should be allowed for cross-origin requests. Defaults to `['GET']`. You can use `['*']` to allow all standard methods.
* `allow_headers` - A list of HTTP request headers that should be supported for cross-origin requests. Defaults to `[]`. You can use `['*']` to allow all headers. The `Accept`, `Accept-Language`, `Content-Language` and `Content-Type` headers are always allowed for CORS requests.
* `allow_credentials` - Indicate that cookies should be supported for cross-origin requests. Defaults to `False`. For credentials to be allowed, `allow_origins`, `allow_methods` and `allow_headers` cannot be set to `['*']`; all of them must be explicitly specified.
* `allow_private_network` - Indicates whether to accept cross-origin requests over a private network. Defaults to `False`.
* `expose_headers` - Indicate any response headers that should be made accessible to the browser. Defaults to `[]`.
* `max_age` - Sets a maximum time in seconds for browsers to cache CORS responses. Defaults to `600`.
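The warning about `.*` in `allow_origin_regex` can be checked with the standard `re` module. The sketch below assumes the pattern is matched against the whole origin string (as `re.fullmatch` does); the `crafted` value is a made-up string chosen only to show what a permissive pattern would accept.

```python
import re

loose = re.compile(r"https://.*\.example\.org")
strict = re.compile(r"https://[a-zA-Z0-9-]+\.example\.org")

legit = "https://api.example.org"
crafted = "https://evil.test/?x=.example.org"  # hypothetical hostile value

loose_accepts_legit = loose.fullmatch(legit) is not None
loose_accepts_crafted = loose.fullmatch(crafted) is not None
strict_accepts_legit = strict.fullmatch(legit) is not None
strict_accepts_crafted = strict.fullmatch(crafted) is not None
```

Both patterns accept the legitimate origin, but only the loose pattern also accepts the crafted value, because `.*` matches `/`, `?` and `=` while the character class does not.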
The middleware responds to two particular types of HTTP request...
#### CORS preflight requests
These are any `OPTIONS` request with `Origin` and `Access-Control-Request-Method` headers.
In this case the middleware will intercept the incoming request and respond with
appropriate CORS headers, and either a 200 or 400 response for informational purposes.
#### Simple requests
Any request with an `Origin` header. In this case the middleware will pass the
request through as normal, but will include appropriate CORS headers on the response.
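The distinction between the two request types can be expressed as a small classifier. This is a sketch of the rules described above, not the middleware's own code; `classify_cors_request` is a hypothetical helper.

```python
def classify_cors_request(method: str, headers: dict) -> str:
    """Classify a request as 'preflight', 'simple' or 'not-cors'."""
    # Header names are case-insensitive, so normalise them first.
    lowered = {name.lower(): value for name, value in headers.items()}

    if "origin" not in lowered:
        return "not-cors"
    if method.upper() == "OPTIONS" and "access-control-request-method" in lowered:
        return "preflight"
    return "simple"
```

For example, an `OPTIONS` request carrying both `Origin` and `Access-Control-Request-Method` is classified as a preflight, while a `GET` with only `Origin` is a simple request.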
#### Private Network Access (PNA)
Private Network Access is a browser security feature that restricts websites on public networks from accessing servers on private networks.
When a website attempts to make such a cross-network request, the browser will send an `Access-Control-Request-Private-Network: true` header in the
pre-flight request. If the `allow_private_network` flag is set to `True`, the middleware will include the `Access-Control-Allow-Private-Network: true`
header in the response, allowing the request. If set to `False`, the middleware will return a 400 response, blocking the request.
### CORSMiddleware Global Enforcement
When using CORSMiddleware with your Starlette application, it's important to ensure that CORS headers are applied even to error responses generated by unhandled exceptions. The recommended solution is to wrap the entire Starlette application with CORSMiddleware. This approach guarantees that even if an exception is caught by ServerErrorMiddleware (or other outer error-handling middleware), the response will still include the proper `Access-Control-Allow-Origin` header.
For example, instead of adding CORSMiddleware as an inner `middleware` via the Starlette middleware parameter, you can wrap your application as follows:
```python
from starlette.applications import Starlette
from starlette.middleware.cors import CORSMiddleware
import uvicorn
app = Starlette()
app = CORSMiddleware(app=app, allow_origins=["*"])
# ... your routes and middleware configuration ...
if __name__ == '__main__':
    uvicorn.run(
        app,
        host='0.0.0.0',
        port=8000
    )
```
## SessionMiddleware
Adds signed cookie-based HTTP sessions. Session information is readable but not modifiable.
The session cookie is always set with the `"HttpOnly"` flag, preventing client-side JavaScript from accessing it.
Access or modify the session data using the `request.session` dictionary interface.
The following arguments are supported:
* `secret_key` - Should be a random string.
* `session_cookie` - Defaults to "session".
* `max_age` - Session expiry time in seconds. Defaults to 2 weeks. If set to `None` then the cookie will last as long as the browser session.
* `same_site` - SameSite flag prevents the browser from sending session cookie along with cross-site requests. Defaults to `'lax'`.
* `path` - The path set for the session cookie. Defaults to `'/'`.
* `https_only` - Indicate that the `"Secure"` flag should be set (can be used with HTTPS only). Defaults to `False`. Set this to `True` in production to ensure the session cookie is only sent over HTTPS.
* `domain` - Domain of the cookie used to share cookie between subdomains or cross-domains. The browser defaults the domain to the same host that set the cookie, excluding subdomains ([reference](https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#domain_attribute)).
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
routes = ...
middleware = [
Middleware(SessionMiddleware, secret_key=..., https_only=True)
]
app = Starlette(routes=routes, middleware=middleware)
```
## HTTPSRedirectMiddleware
Enforces that all incoming requests must either be `https` or `wss`. Any incoming
requests to `http` or `ws` will be redirected to the secure scheme instead.
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
routes = ...
middleware = [
Middleware(HTTPSRedirectMiddleware)
]
app = Starlette(routes=routes, middleware=middleware)
```
There are no configuration options for this middleware class.
## TrustedHostMiddleware
Enforces that all incoming requests have a correctly set `Host` header, in order
to guard against HTTP Host Header attacks.
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
routes = ...
middleware = [
Middleware(TrustedHostMiddleware, allowed_hosts=['example.com', '*.example.com'])
]
app = Starlette(routes=routes, middleware=middleware)
```
The following arguments are supported:
* `allowed_hosts` - A list of domain names that should be allowed as hostnames. Wildcard
domains such as `*.example.com` are supported for matching subdomains. To allow any
hostname either use `allowed_hosts=["*"]` or omit the middleware.
* `www_redirect` - If set to True, requests to non-www versions of the allowed hosts will be redirected to their www counterparts. Defaults to `True`.
If an incoming request does not validate correctly then a 400 response will be sent.
## GZipMiddleware
Handles GZip responses for any request that includes `"gzip"` in the `Accept-Encoding` header.
The middleware will handle both standard and streaming responses.
??? info "Buffer on streaming responses"
    On streaming responses, the middleware will buffer the response before compressing it.

    The idea is that we don't want to compress every small chunk of data, as it would be inefficient.
    Instead, we buffer the response until it reaches a certain size, and then compress it.

    This may cause a delay in the response, as the middleware waits for the buffer to fill up before compressing it.
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
routes = ...
middleware = [
Middleware(GZipMiddleware, minimum_size=1000, compresslevel=9)
]
app = Starlette(routes=routes, middleware=middleware)
```
The following arguments are supported:
* `minimum_size` - Do not GZip responses that are smaller than this minimum size in bytes. Defaults to `500`.
* `compresslevel` - Used during GZip compression. It is an integer ranging from 1 to 9. Defaults to `9`. Lower value results in faster compression but larger file sizes, while higher value results in slower compression but smaller file sizes.
The middleware won't GZip responses that already have a `Content-Encoding` set, to prevent them from
being encoded twice, or that have a `Content-Type` of `text/event-stream`, to avoid compressing server-sent events.
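The effect of `minimum_size` and `compresslevel` can be explored with the standard `gzip` module. The sketch below is independent of the middleware: `maybe_compress` is a hypothetical helper that mirrors the two options by leaving small bodies untouched and compressing the rest at the chosen level.

```python
import gzip

body = b"Hello, Starlette! " * 200  # 3600 bytes of repetitive text


def maybe_compress(data: bytes, minimum_size: int = 500, compresslevel: int = 9):
    """Return (payload, was_compressed) for a response body."""
    if len(data) < minimum_size:
        # Small bodies are returned unchanged.
        return data, False
    return gzip.compress(data, compresslevel=compresslevel), True


small, small_compressed = maybe_compress(b"tiny")
fast, fast_compressed = maybe_compress(body, compresslevel=1)
best, best_compressed = maybe_compress(body, compresslevel=9)

print(len(body), len(fast), len(best))
```

Printing the three lengths gives a quick way to compare the size trade-off between levels for your own payloads.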
## BaseHTTPMiddleware
An abstract class that allows you to write ASGI middleware against a request/response
interface.
### Usage
To implement a middleware class using `BaseHTTPMiddleware`, you must override the
`async def dispatch(request, call_next)` method.
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
class CustomHeaderMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        response = await call_next(request)
        response.headers['Custom'] = 'Example'
        return response

routes = ...
middleware = [
Middleware(CustomHeaderMiddleware)
]
app = Starlette(routes=routes, middleware=middleware)
```
If you want to provide configuration options to the middleware class you should
override the `__init__` method, ensuring that the first argument is `app`, and
any remaining arguments are optional keyword arguments. Make sure to call
`super().__init__(app)` if you do this, so that the `app` attribute is set on the instance.
```python
class CustomHeaderMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, header_value='Example'):
        super().__init__(app)
        self.header_value = header_value

    async def dispatch(self, request, call_next):
        response = await call_next(request)
        response.headers['Custom'] = self.header_value
        return response

middleware = [
Middleware(CustomHeaderMiddleware, header_value='Customized')
]
app = Starlette(routes=routes, middleware=middleware)
```
Middleware classes should not modify their state outside of the `__init__` method.
Instead you should keep any state local to the `dispatch` method, or pass it
around explicitly, rather than mutating the middleware instance.
### Limitations
Currently, the `BaseHTTPMiddleware` has some known limitations:
- Using `BaseHTTPMiddleware` will prevent changes to [`contextvars.ContextVar`](https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar)s from propagating upwards. That is, if you set a value for a `ContextVar` in your endpoint and try to read it from a middleware you will find that the value is not the same value you set in your endpoint (see [this test](https://github.com/Kludex/starlette/blob/621abc747a6604825190b93467918a0ec6456a24/tests/middleware/test_base.py#L192-L223) for an example of this behavior). Importantly, this also means that if a `BaseHTTPMiddleware` is positioned earlier in the middleware stack, it will disrupt `contextvars` propagation for any subsequent Pure ASGI Middleware that relies on them.
To overcome these limitations, use [pure ASGI middleware](#pure-asgi-middleware), as shown below.
## Pure ASGI Middleware
The [ASGI spec](https://asgi.readthedocs.io/en/latest/) makes it possible to implement ASGI middleware using the ASGI interface directly, as a chain of ASGI applications that call into the next one. In fact, this is how middleware classes shipped with Starlette are implemented.
This lower-level approach provides greater control over behavior and enhanced interoperability across frameworks and servers. It also overcomes the [limitations of `BaseHTTPMiddleware`](#limitations).
### Writing pure ASGI middleware
The most common way to create an ASGI middleware is with a class.
```python
class ASGIMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        await self.app(scope, receive, send)
```
The middleware above is the most basic ASGI middleware. It receives a parent ASGI application as an argument for its constructor, and implements an `async __call__` method which calls into that parent application.
Some implementations such as [`asgi-cors`](https://github.com/simonw/asgi-cors/blob/10ef64bfcc6cd8d16f3014077f20a0fb8544ec39/asgi_cors.py) use an alternative style, using functions:
```python
import functools
def asgi_middleware():
    def asgi_decorator(app):

        @functools.wraps(app)
        async def wrapped_app(scope, receive, send):
            await app(scope, receive, send)

        return wrapped_app

    return asgi_decorator
```
In any case, ASGI middleware must be callables that accept three arguments: `scope`, `receive`, and `send`.
* `scope` is a dict holding information about the connection, where `scope["type"]` may be:
    * [`"http"`](https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope): for HTTP requests.
    * [`"websocket"`](https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope): for WebSocket connections.
    * [`"lifespan"`](https://asgi.readthedocs.io/en/latest/specs/lifespan.html#scope): for ASGI lifespan messages.
* `receive` and `send` can be used to exchange ASGI event messages with the ASGI server — more on this below. The type and contents of these messages depend on the scope type. Learn more in the [ASGI specification](https://asgi.readthedocs.io/en/latest/specs/index.html).
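Because the interface is just three arguments, a middleware can be exercised without a server. The following sketch plays the server's part for a single HTTP request; `hello_app` and `call_once` are hypothetical names introduced for illustration, and the scope contains only a handful of the keys a real server would provide.

```python
import asyncio


class ASGIMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        await self.app(scope, receive, send)


async def hello_app(scope, receive, send):
    """A tiny ASGI app that answers every HTTP request with plain text."""
    await send({
        "type": "http.response.start",
        "status": 200,
        "headers": [(b"content-type", b"text/plain")],
    })
    await send({"type": "http.response.body", "body": b"Hello, world!"})


async def call_once(app):
    """Send one GET request to `app` and collect the messages it emits."""
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    scope = {"type": "http", "method": "GET", "path": "/", "headers": []}
    await app(scope, receive, send)
    return sent


messages = asyncio.run(call_once(ASGIMiddleware(hello_app)))
```

The collected `messages` list holds the `http.response.start` event followed by the `http.response.body` event, exactly as the wrapped app sent them.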
### Using pure ASGI middleware
Pure ASGI middleware can be used like any other middleware:
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from .middleware import ASGIMiddleware
routes = ...
middleware = [
Middleware(ASGIMiddleware),
]
app = Starlette(..., middleware=middleware)
```
See also [Using middleware](#using-middleware).
### Type annotations
There are two ways of annotating a middleware: using Starlette itself or [`asgiref`](https://github.com/django/asgiref).
* Using Starlette: for most common use cases.
```python
from starlette.types import ASGIApp, Message, Scope, Receive, Send
class ASGIMiddleware:
    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        async def send_wrapper(message: Message) -> None:
            # ... Do something
            await send(message)

        await self.app(scope, receive, send_wrapper)
```
* Using [`asgiref`](https://github.com/django/asgiref): for more rigorous type hinting.
```python
from asgiref.typing import ASGI3Application, ASGIReceiveCallable, ASGISendCallable, Scope
from asgiref.typing import ASGIReceiveEvent, ASGISendEvent
class ASGIMiddleware:
    def __init__(self, app: ASGI3Application) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        async def send_wrapper(message: ASGISendEvent) -> None:
            # ... Do something
            await send(message)

        return await self.app(scope, receive, send_wrapper)
```
### Common patterns
#### Processing certain requests only
ASGI middleware can apply specific behavior according to the contents of `scope`.
For example, to only process HTTP requests, write this...
```python
class ASGIMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        ...  # Do something here!

        await self.app(scope, receive, send)
```
Likewise, WebSocket-only middleware would guard on `scope["type"] != "websocket"`.
The middleware may also act differently based on the request method, URL, headers, etc.
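As a sketch of acting on the method, URL and headers, the hypothetical `SelectiveLoggingMiddleware` below logs only `POST` requests under a path prefix and reads the `User-Agent` from the raw ASGI headers. The class name, the `log` callable and the `prefix` parameter are assumptions made for this example.

```python
import asyncio


class SelectiveLoggingMiddleware:
    """Hypothetical middleware: logs only POST requests under a path prefix."""

    def __init__(self, app, log, prefix="/api/"):
        self.app = app
        self.log = log  # any callable that accepts a string
        self.prefix = prefix

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            if scope["method"] == "POST" and scope["path"].startswith(self.prefix):
                # Raw ASGI headers are (name, value) pairs of bytes with lowercased names.
                headers = dict(scope["headers"])
                agent = headers.get(b"user-agent", b"unknown").decode("latin-1")
                self.log(f"POST {scope['path']} from {agent}")
        await self.app(scope, receive, send)


async def noop_app(scope, receive, send):
    """Placeholder downstream app that does nothing."""


def http_scope(method, path, headers=()):
    """Build a minimal HTTP scope for the demonstration."""
    return {"type": "http", "method": method, "path": path, "headers": list(headers)}


lines = []
app = SelectiveLoggingMiddleware(noop_app, log=lines.append)

# `receive` and `send` are unused by `noop_app`, so None is passed for both.
asyncio.run(app(http_scope("GET", "/api/items"), None, None))
asyncio.run(app(http_scope("POST", "/health"), None, None))
asyncio.run(app(http_scope("POST", "/api/items", [(b"user-agent", b"curl/8.0")]), None, None))
asyncio.run(app({"type": "websocket", "path": "/api/ws", "headers": []}, None, None))
```

Only the third call satisfies all the conditions, so `lines` ends up with a single entry.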
#### Reusing Starlette components
Starlette provides several data structures that accept the ASGI `scope`, `receive` and/or `send` arguments, allowing you to work at a higher level of abstraction. Such data structures include [`Request`](requests.md#request), [`Headers`](requests.md#headers), [`QueryParams`](requests.md#query-parameters), [`URL`](requests.md#url), etc.
For example, you can instantiate a `Request` to more easily inspect an HTTP request:
```python
from starlette.requests import Request
class ASGIMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            request = Request(scope)
            ...  # Use `request.method`, `request.url`, `request.headers`, etc.

        await self.app(scope, receive, send)
```
You can also reuse [responses](responses.md), which are ASGI applications as well.
#### Sending eager responses
Inspecting the connection `scope` allows you to conditionally call into a different ASGI app. One use case might be sending a response without calling into the app.
As an example, this middleware uses a dictionary to perform permanent redirects based on the requested path. This could be used to implement ongoing support of legacy URLs in case you need to refactor route URL patterns.
```python
from starlette.datastructures import URL
from starlette.responses import RedirectResponse
class RedirectsMiddleware:
    def __init__(self, app, path_mapping: dict):
        self.app = app
        self.path_mapping = path_mapping

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        url = URL(scope=scope)

        if url.path in self.path_mapping:
            url = url.replace(path=self.path_mapping[url.path])
            response = RedirectResponse(url, status_code=301)
            await response(scope, receive, send)
            return

        await self.app(scope, receive, send)
```
Example usage would look like this:
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
routes = ...
redirections = {
    "/v1/resource/": "/v2/resource/",
    # ...
}

middleware = [
    Middleware(RedirectsMiddleware, path_mapping=redirections),
]
app = Starlette(routes=routes, middleware=middleware)
```
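The same inspection of `scope` can hand a connection to an entirely different ASGI app instead of a response object. The sketch below is one possible shape for that, written without Starlette so it runs on its own; `HealthCheckDispatcher`, `health_app`, `main_app`, and the `/healthz` path are hypothetical names, not part of any library.

```python
import asyncio


class HealthCheckDispatcher:
    """Send `/healthz` to a dedicated ASGI app and everything else to `app`."""

    def __init__(self, app, health_app, path: str = "/healthz"):
        self.app = app
        self.health_app = health_app
        self.path = path

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http" and scope["path"] == self.path:
            await self.health_app(scope, receive, send)
            return
        await self.app(scope, receive, send)


async def health_app(scope, receive, send):
    await send({"type": "http.response.start", "status": 200,
                "headers": [(b"content-type", b"text/plain")]})
    await send({"type": "http.response.body", "body": b"ok"})


async def main_app(scope, receive, send):
    await send({"type": "http.response.start", "status": 404, "headers": []})
    await send({"type": "http.response.body", "body": b"not found"})


async def call(app, path):
    # Drive the app with a minimal HTTP scope and collect what it sends.
    sent = []

    async def send(message):
        sent.append(message)

    await app({"type": "http", "path": path}, None, send)
    return sent


app = HealthCheckDispatcher(main_app, health_app)
health = asyncio.run(call(app, "/healthz"))
other = asyncio.run(call(app, "/anything"))
```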
#### Inspecting or modifying the request
Request information can be accessed or changed by manipulating the `scope`. For a full example of this pattern, see Uvicorn's [`ProxyHeadersMiddleware`](https://github.com/encode/uvicorn/blob/fd4386fefb8fe8a4568831a7d8b2930d5fb61455/uvicorn/middleware/proxy_headers.py) which inspects and tweaks the `scope` when serving behind a frontend proxy.
In addition, wrapping the `receive` ASGI callable allows you to access or modify the HTTP request body by manipulating [`http.request`](https://asgi.readthedocs.io/en/latest/specs/www.html#request-receive-event) ASGI event messages.
As an example, this middleware computes and logs the size of the incoming request body:
```python
class LoggedRequestBodySizeMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        body_size = 0

        async def receive_logging_request_body_size():
            nonlocal body_size

            message = await receive()
            # `receive()` may also yield `http.disconnect`; pass that through untouched.
            if message["type"] == "http.request":
                body_size += len(message.get("body", b""))
                if not message.get("more_body", False):
                    print(f"Size of request body was: {body_size} bytes")

            return message

        await self.app(scope, receive_logging_request_body_size, send)
```
Likewise, WebSocket middleware may manipulate [`websocket.receive`](https://asgi.readthedocs.io/en/latest/specs/www.html#receive-receive-event) ASGI event messages to inspect or alter incoming WebSocket data.
For an example that changes the HTTP request body, see [`msgpack-asgi`](https://github.com/florimondmanca/msgpack-asgi).
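For the WebSocket case, wrapping `receive` follows the same shape. The sketch below counts incoming `websocket.receive` events and reports the total on disconnect. It is an illustration only: `WebSocketMessageCountMiddleware`, its `log` parameter, and the `demo` driver with canned events are hypothetical and not part of Starlette.

```python
import asyncio


class WebSocketMessageCountMiddleware:
    def __init__(self, app, log=print):
        self.app = app
        self.log = log  # Hypothetical hook so the output can be redirected.

    async def __call__(self, scope, receive, send):
        if scope["type"] != "websocket":
            await self.app(scope, receive, send)
            return

        message_count = 0  # Connection-specific state, scoped to this call.

        async def counting_receive():
            nonlocal message_count
            message = await receive()
            if message["type"] == "websocket.receive":
                message_count += 1
            elif message["type"] == "websocket.disconnect":
                self.log(f"Client sent {message_count} messages")
            return message

        await self.app(scope, counting_receive, send)


async def demo(log=print):
    # Canned events standing in for a real client connection.
    events = iter([
        {"type": "websocket.connect"},
        {"type": "websocket.receive", "text": "a"},
        {"type": "websocket.receive", "text": "b"},
        {"type": "websocket.disconnect", "code": 1000},
    ])

    async def receive():
        return next(events)

    async def send(message):
        pass

    async def app(scope, receive, send):
        # Stand-in app that drains events until the client disconnects.
        while (await receive())["type"] != "websocket.disconnect":
            pass

    middleware = WebSocketMessageCountMiddleware(app, log=log)
    await middleware({"type": "websocket"}, receive, send)


asyncio.run(demo())
```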
#### Inspecting or modifying the response
Wrapping the `send` ASGI callable allows you to inspect or modify the HTTP response sent by the underlying application. To do so, react to [`http.response.start`](https://asgi.readthedocs.io/en/latest/specs/www.html#response-start-send-event) or [`http.response.body`](https://asgi.readthedocs.io/en/latest/specs/www.html#response-body-send-event) ASGI event messages.
As an example, this middleware adds some fixed extra response headers:
```python
from starlette.datastructures import MutableHeaders


class ExtraResponseHeadersMiddleware:
    def __init__(self, app, headers):
        self.app = app
        self.headers = headers

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        async def send_with_extra_headers(message):
            if message["type"] == "http.response.start":
                headers = MutableHeaders(scope=message)
                for key, value in self.headers:
                    headers.append(key, value)

            await send(message)

        await self.app(scope, receive, send_with_extra_headers)
```
See also [`asgi-logger`](https://github.com/Kludex/asgi-logger/blob/main/asgi_logger/middleware.py) for an example that inspects the HTTP response and logs a configurable HTTP access log line.
Likewise, WebSocket middleware may manipulate [`websocket.send`](https://asgi.readthedocs.io/en/latest/specs/www.html#send-send-event) ASGI event messages to inspect or alter outgoing WebSocket data.
Note that if you change the response body, you will need to update the response `Content-Length` header to match the new response body length. See [`brotli-asgi`](https://github.com/fullonic/brotli-asgi) for a complete example.
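To make the `Content-Length` point concrete, here is a minimal sketch that appends a fixed footer to the response body. It assumes buffering the whole response is acceptable, so it is not suitable for streaming responses; `AppendFooterMiddleware` and the `demo` driver are hypothetical names for this example. The `http.response.start` message is held back until the final body size is known.

```python
import asyncio


class AppendFooterMiddleware:
    def __init__(self, app, footer: bytes):
        self.app = app
        self.footer = footer

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        start_message = None  # Held back until the final body size is known.
        chunks = []

        async def send_with_footer(message):
            nonlocal start_message
            if message["type"] == "http.response.start":
                start_message = message
            elif message["type"] == "http.response.body":
                chunks.append(message.get("body", b""))
                if message.get("more_body", False):
                    return  # Keep buffering until the last chunk arrives.
                body = b"".join(chunks) + self.footer
                # Drop any stale Content-Length and add the corrected one.
                headers = [
                    (name, value)
                    for name, value in start_message.get("headers", [])
                    if name.lower() != b"content-length"
                ]
                headers.append((b"content-length", str(len(body)).encode("latin-1")))
                await send({**start_message, "headers": headers})
                await send({"type": "http.response.body", "body": body})
            else:
                await send(message)

        await self.app(scope, receive, send_with_footer)


async def demo():
    sent = []

    async def app(scope, receive, send):
        await send({"type": "http.response.start", "status": 200,
                    "headers": [(b"content-length", b"5")]})
        await send({"type": "http.response.body", "body": b"hello"})

    async def send(message):
        sent.append(message)

    await AppendFooterMiddleware(app, footer=b"!")({"type": "http"}, None, send)
    return sent


sent = asyncio.run(demo())
```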
#### Passing information to endpoints
If you need to share information with the underlying app or endpoints, you may store it in the `scope` dictionary. Note that this is a convention -- for example, Starlette uses this to share routing information with endpoints -- but it is not part of the ASGI specification. If you do so, avoid conflicts by using keys that are unlikely to be used by other middleware or applications.
For example, when including the middleware below, endpoints would be able to access `request.scope["asgi_transaction_id"]`.
```python
import uuid


class TransactionIDMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        scope["asgi_transaction_id"] = uuid.uuid4()
        await self.app(scope, receive, send)
```
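On the receiving side, the wrapped app reads the value back from the same `scope`. The sketch below uses a bare ASGI callable in place of a Starlette endpoint so that it runs on its own; `show_transaction_id` and `demo` are hypothetical names, and in a Starlette endpoint the equivalent lookup would be `request.scope["asgi_transaction_id"]`.

```python
import asyncio
import uuid


class TransactionIDMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        scope["asgi_transaction_id"] = uuid.uuid4()
        await self.app(scope, receive, send)


async def show_transaction_id(scope, receive, send):
    # Hypothetical endpoint: echoes the ID stored by the middleware.
    body = str(scope["asgi_transaction_id"]).encode()
    await send({"type": "http.response.start", "status": 200,
                "headers": [(b"content-type", b"text/plain")]})
    await send({"type": "http.response.body", "body": body})


async def demo():
    sent = []

    async def send(message):
        sent.append(message)

    app = TransactionIDMiddleware(show_transaction_id)
    await app({"type": "http"}, None, send)
    return sent


sent = asyncio.run(demo())
```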
#### Cleanup and error handling
You can wrap the application in a `try/except/finally` block or a context manager to perform cleanup operations or do error handling.
For example, the following middleware might collect metrics and process application exceptions:
```python
import time


class MonitoringMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        start = time.time()
        try:
            await self.app(scope, receive, send)
        except Exception as exc:
            ...  # Process the exception
            raise
        finally:
            end = time.time()
            elapsed = end - start
            ...  # Submit `elapsed` as a metric to a monitoring backend
```
See also [`timing-asgi`](https://github.com/steinnes/timing-asgi) for a full example of this pattern.
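The context manager variant mentioned above could be sketched as follows. This is an illustration under stated assumptions: `record_elapsed`, `ElapsedTimeMiddleware`, and the `metrics` list (a stand-in for a monitoring backend) are hypothetical names, and `time.perf_counter()` is used here simply because it suits interval measurement.

```python
import asyncio
import time
from contextlib import contextmanager


@contextmanager
def record_elapsed(metrics: list):
    start = time.perf_counter()
    try:
        yield
    finally:
        # Runs whether the wrapped app returned normally or raised.
        metrics.append(time.perf_counter() - start)


class ElapsedTimeMiddleware:
    def __init__(self, app, metrics: list):
        self.app = app
        self.metrics = metrics  # Shared sink for results, not per-request state.

    async def __call__(self, scope, receive, send):
        with record_elapsed(self.metrics):
            await self.app(scope, receive, send)


async def ok_app(scope, receive, send):
    await asyncio.sleep(0)


async def failing_app(scope, receive, send):
    raise RuntimeError("boom")


metrics = []
errors = []
asyncio.run(ElapsedTimeMiddleware(ok_app, metrics)({"type": "http"}, None, None))
try:
    asyncio.run(ElapsedTimeMiddleware(failing_app, metrics)({"type": "http"}, None, None))
except RuntimeError as exc:
    # The exception still propagates; the timing was recorded regardless.
    errors.append(str(exc))
```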
### Gotchas
#### ASGI middleware should be stateless
Because ASGI is designed to handle concurrent requests, any connection-specific state should be scoped to the `__call__` implementation. Not doing so would typically lead to conflicting variable reads/writes across requests, and most likely bugs.
As an example, this middleware conditionally replaces the response body if the response carries an `X-Mock: 1` header:
=== "✅ Do"
    ```python
    from starlette.datastructures import Headers

    class MockResponseBodyMiddleware:
        def __init__(self, app, content):
            self.app = app
            self.content = content

        async def __call__(self, scope, receive, send):
            if scope["type"] != "http":
                await self.app(scope, receive, send)
                return

            # A flag that we will turn `True` if the HTTP response
            # has the 'X-Mock' header.
            # ✅: Scoped to this function.
            should_mock = False

            async def maybe_send_with_mock_content(message):
                nonlocal should_mock

                if message["type"] == "http.response.start":
                    headers = Headers(raw=message["headers"])
                    should_mock = headers.get("X-Mock") == "1"
                    await send(message)

                elif message["type"] == "http.response.body":
                    if should_mock:
                        message = {"type": "http.response.body", "body": self.content}
                    await send(message)

            await self.app(scope, receive, maybe_send_with_mock_content)
    ```
=== "❌ Don't"
    ```python hl_lines="7-8"
    from starlette.datastructures import Headers

    class MockResponseBodyMiddleware:
        def __init__(self, app, content):
            self.app = app
            self.content = content
            # ❌: This variable would be read and written across requests!
            self.should_mock = False

        async def __call__(self, scope, receive, send):
            if scope["type"] != "http":
                await self.app(scope, receive, send)
                return

            async def maybe_send_with_mock_content(message):
                if message["type"] == "http.response.start":
                    headers = Headers(raw=message["headers"])
                    self.should_mock = headers.get("X-Mock") == "1"
                    await send(message)

                elif message["type"] == "http.response.body":
                    if self.should_mock:
                        message = {"type": "http.response.body", "body": self.content}
                    await send(message)

            await self.app(scope, receive, maybe_send_with_mock_content)
    ```
See also [`GZipMiddleware`](https://github.com/Kludex/starlette/blob/9ef1b91c9c043197da6c3f38aa153fd874b95527/starlette/middleware/gzip.py) for a full example implementation that navigates this potential gotcha.
### Further reading
This documentation should give you a good basis for creating ASGI middleware.
For more depth, see these articles:
- [Introduction to ASGI: Emergence of an Async Python Web Ecosystem](https://florimond.dev/en/posts/2019/08/introduction-to-asgi-async-python-web/)
- [How to write ASGI middleware](https://pgjones.dev/blog/how-to-write-asgi-middleware-2021/)
## Using middleware in other frameworks
To wrap ASGI middleware around other ASGI applications, you should use the
more general pattern of wrapping the application instance:
```python
app = TrustedHostMiddleware(app, allowed_hosts=['example.com'])
```
You can do this with a Starlette application instance too, but it is preferable
to use the `middleware=` style, as it will:
* Ensure that everything remains wrapped in a single outermost `ServerErrorMiddleware`.
* Preserve the top-level `app` instance.
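When wrapping by hand, keep in mind that the wrapper applied last becomes the outermost layer and therefore sees each connection first. The sketch below shows that ordering with a bare ASGI callable; `TagMiddleware`, `plain_app`, and the `tags` scope key are hypothetical names used only for illustration.

```python
import asyncio


class TagMiddleware:
    def __init__(self, app, tag: str):
        self.app = app
        self.tag = tag

    async def __call__(self, scope, receive, send):
        # Record the order in which the layers are entered.
        scope.setdefault("tags", []).append(self.tag)
        await self.app(scope, receive, send)


order = []


async def plain_app(scope, receive, send):
    # Stand-in for any ASGI application, not necessarily a Starlette one.
    order.extend(scope["tags"])


app = TagMiddleware(plain_app, tag="inner")
app = TagMiddleware(app, tag="outer")  # Applied last, so it runs first.
asyncio.run(app({"type": "http"}, None, None))
```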
## Applying middleware to groups of routes
Middleware can also be added to `Mount` instances, which allows you to apply middleware to a group of routes or a sub-application:
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.routing import Mount, Route
routes = [
    Mount(
        "/",
        routes=[
            Route(
                "/example",
                endpoint=...,
            )
        ],
        middleware=[Middleware(GZipMiddleware)]
    )
]

app = Starlette(routes=routes)
```
Note that middleware used in this way is *not* wrapped in exception handling middleware like the middleware applied to the `Starlette` application is.
This is often not a problem because it only applies to middleware that inspect or modify the `Response`, and even then you probably don't want to apply this logic to error responses.
If you do want to apply the middleware logic to error responses on only some routes, you have a few options:
* Add an `ExceptionMiddleware` onto the `Mount`
* Add a `try/except` block to your middleware and return an error response from there
* Split up marking and processing into two middlewares, one that gets put on `Mount` which marks the response as needing processing (for example by setting `scope["log-response"] = True`) and another applied to the `Starlette` application that does the heavy lifting.
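The third option could be sketched as below. This is a minimal illustration without Starlette: `MarkForLoggingMiddleware`, `ResponseLoggingMiddleware`, the `log` parameter, and the `router` stand-in are hypothetical, and the sketch assumes the inner and outer layers see the same `scope` dictionary.

```python
import asyncio


class MarkForLoggingMiddleware:
    """Cheap marker, intended for a `Mount` or a single route."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        scope["log-response"] = True
        await self.app(scope, receive, send)


class ResponseLoggingMiddleware:
    """Does the actual work, intended for the application level."""

    def __init__(self, app, log=print):
        self.app = app
        self.log = log

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        async def send_and_log(message):
            # The marker is checked lazily: it only exists once the request
            # has passed through the inner middleware.
            if message["type"] == "http.response.start" and scope.get("log-response"):
                self.log(f"{scope['path']} -> {message['status']}")
            await send(message)

        await self.app(scope, receive, send_and_log)


async def endpoint(scope, receive, send):
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})


marked_endpoint = MarkForLoggingMiddleware(endpoint)


async def router(scope, receive, send):
    # Stand-in for routing: only `/marked` passes through the marker.
    target = marked_endpoint if scope["path"] == "/marked" else endpoint
    await target(scope, receive, send)


async def discard(message):
    pass


lines = []
app = ResponseLoggingMiddleware(router, log=lines.append)
for path in ("/marked", "/plain"):
    asyncio.run(app({"type": "http", "path": path}, None, discard))
```

Only the `/marked` request ends up in `lines`, because the outer middleware acts solely on responses whose scope was marked further in.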
The `Route` and `WebSocketRoute` classes also accept a `middleware` argument, which allows you to apply middleware to a single route:
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.routing import Route
routes = [
    Route(
        "/example",
        endpoint=...,
        middleware=[Middleware(GZipMiddleware)]
    )
]

app = Starlette(routes=routes)
```
You can also apply middleware to the `Router` class, which allows you to apply middleware to a group of routes:
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.routing import Route, Router
routes = [
    Route("/example", endpoint=...),
    Route("/another", endpoint=...),
]

router = Router(routes=routes, middleware=[Middleware(GZipMiddleware)])
```
## Third party middleware
#### [asgi-auth-github](https://github.com/simonw/asgi-auth-github)
This middleware adds authentication to any ASGI application, requiring users to sign in
using their GitHub account (via [OAuth](https://developer.github.com/apps/building-oauth-apps/authorizing-oauth-apps/)).
Access can be restricted to specific users or to members of specific GitHub organizations or teams.
#### [asgi-csrf](https://github.com/simonw/asgi-csrf)
Middleware for protecting against CSRF attacks. This middleware implements the Double Submit Cookie pattern, where a cookie is set, then it is compared to a csrftoken hidden form field or an `x-csrftoken` HTTP header.
#### [AuthlibMiddleware](https://github.com/aogier/starlette-authlib)
A drop-in replacement for Starlette session middleware, using [authlib's jwt](https://docs.authlib.org/en/latest/jose/jwt.html)
module.
#### [BugsnagMiddleware](https://github.com/ashinabraham/starlette-bugsnag)
A middleware class for logging exceptions to [Bugsnag](https://www.bugsnag.com/).
#### [CSRFMiddleware](https://github.com/frankie567/starlette-csrf)
Middleware for protecting against CSRF attacks. This middleware implements the Double Submit Cookie pattern, where a cookie is set, then it is compared to an `x-csrftoken` HTTP header.
#### [EarlyDataMiddleware](https://github.com/HarrySky/starlette-early-data)
Middleware and decorator for detecting and denying [TLSv1.3 early data](https://tools.ietf.org/html/rfc8470) requests.
#### [PrometheusMiddleware](https://github.com/perdy/starlette-prometheus)
A middleware class for capturing Prometheus metrics related to requests and responses, such as in-progress requests and timing.
#### [ProxyHeadersMiddleware](https://github.com/encode/uvicorn/blob/main/uvicorn/middleware/proxy_headers.py)
Uvicorn includes a middleware class for determining the client IP address,
when proxy servers are being used, based on the `X-Forwarded-Proto` and `X-Forwarded-For` headers. For more complex proxy configurations, you might want to adapt this middleware.
#### [RateLimitMiddleware](https://github.com/abersheeran/asgi-ratelimit)
A rate-limiting middleware. It matches URLs with regular expressions and supports flexible, highly customizable rules.
#### [RequestIdMiddleware](https://github.com/snok/asgi-correlation-id)
A middleware class for reading/generating request IDs and attaching them to application logs.
#### [RollbarMiddleware](https://docs.rollbar.com/docs/starlette)
A middleware class for logging exceptions, errors, and log messages to [Rollbar](https://www.rollbar.com).
#### [StarletteOpentracing](https://github.com/acidjunk/starlette-opentracing)
A middleware class that emits tracing info to [OpenTracing.io](https://opentracing.io/) compatible tracers and
can be used to profile and monitor distributed applications.
#### [SecureCookiesMiddleware](https://github.com/thearchitector/starlette-securecookies)
Customizable middleware for adding automatic cookie encryption and decryption to Starlette applications, with
extra support for existing cookie-based middleware.
#### [TimingMiddleware](https://github.com/steinnes/timing-asgi)
A middleware class to emit timing information (cpu and wall time) for each request which
passes through it. Includes examples for how to emit these timings as statsd metrics.
#### [WSGIMiddleware](https://github.com/abersheeran/a2wsgi)
A middleware class in charge of converting a WSGI application into an ASGI one.
================================================
FILE: docs/overrides/main.html
================================================
{% extends "base.html" %}
{% block extrahead %}
{{ super() }}
{% endblock %}
{% block announce %}
{% endblock %}
================================================
FILE: docs/overrides/partials/toc-item.html
================================================
{{ toc_item.title }}
{% if toc_item.children %}
{% endif %}
================================================
FILE: docs/release-notes.md
================================================
---
toc_depth: 2
---
## 1.0.0rc1 (February 23, 2026)
We're ready! I'm thrilled to announce the first release candidate for Starlette 1.0.
Starlette was created in June 2018 by Kim Christie, and has been on ZeroVer for years. Today, it's downloaded
almost [10 million times a day](https://pypistats.org/packages/starlette), serves as the foundation for FastAPI,
and has inspired many other frameworks. In the age of AI, Starlette continues to play an important role as a
dependency of the Python MCP SDK.
This release focuses on removing deprecated features that were marked for removal in 1.0.0, along with some
last minute bug fixes. It's a release candidate, so we can gather feedback from the community before the final
1.0.0 release soon.
A huge thank you to all the contributors who have helped make Starlette what it is today.
In particular, I'd like to recognize:
* [Kim Christie](https://github.com/lovelydinosaur) - The original creator of Starlette, Uvicorn, and MkDocs, and the
current maintainer of HTTPX. Kim's work helped lay the foundation for the modern async Python ecosystem.
* [Adrian Garcia Badaracco](https://github.com/adriangb) - One of the smartest people I know, whom I have the pleasure of working with at Pydantic.
* [Thomas Grainger](https://github.com/graingert) - My async teacher, always ready to help with questions.
* [Alex Grönholm](https://github.com/agronholm) - Another async mentor, always prompt to help with questions.
* [Florimond Manca](https://github.com/florimondmanca) - Always present in the early days of both Starlette and Uvicorn, and helped a lot in the ecosystem.
* [Amin Alaee](https://github.com/aminalaee) - Contributed a lot with file-related PRs.
* [Sebastián Ramírez](https://github.com/tiangolo) - Maintains FastAPI upstream, and always in contact to help with upstream issues.
* [Alex Oleshkevich](https://github.com/alex-oleshkevich) - Helped a lot on templates and many discussions.
* [abersheeran](https://github.com/abersheeran) - My go-to person when I need help on many subjects.
I'd also like to thank my sponsors for their support. A special thanks to
[@tiangolo](https://github.com/tiangolo), [@huggingface](https://github.com/huggingface),
and [@elevenlabs](https://github.com/elevenlabs) for their generous sponsorship, and to all my other sponsors:
[@roboflow](https://github.com/roboflow),
[@ogabrielluiz](https://github.com/ogabrielluiz),
[@SaboniAmine](https://github.com/SaboniAmine),
[@russbiggs](https://github.com/russbiggs),
[@BryceBeagle](https://github.com/BryceBeagle),
[@chdsbd](https://github.com/chdsbd),
[@TheR1D](https://github.com/TheR1D),
[@ddanier](https://github.com/ddanier),
[@larsyngvelundin](https://github.com/larsyngvelundin),
[@jpizquierdo](https://github.com/jpizquierdo),
[@alixlahuec](https://github.com/alixlahuec),
[@nathanchapman](https://github.com/nathanchapman),
[@devid8642](https://github.com/devid8642),
[@comet-ml](https://github.com/comet-ml),
[@Evil0ctal](https://github.com/Evil0ctal),
[@msehnout](https://github.com/msehnout),
and [@codingjoe](https://github.com/codingjoe).
#### Removed
* Remove `on_startup` and `on_shutdown` parameters from `Starlette` and `Router`.
Use the `lifespan` parameter instead [#3117](https://github.com/encode/starlette/pull/3117).
* Remove `on_event()` decorator from `Starlette` and `Router`.
Use the `lifespan` parameter instead [#3117](https://github.com/encode/starlette/pull/3117).
* Remove `add_event_handler()` method from `Starlette` and `Router`.
Use the `lifespan` parameter instead [#3117](https://github.com/encode/starlette/pull/3117).
* Remove `startup()` and `shutdown()` methods from `Router`
[#3117](https://github.com/encode/starlette/pull/3117).
* Remove `@app.route()` decorator from `Starlette` and `Router`.
Use `Route` in the `routes` parameter instead [#3117](https://github.com/encode/starlette/pull/3117).
* Remove `@app.websocket_route()` decorator from `Starlette` and `Router`.
Use `WebSocketRoute` in the `routes` parameter instead [#3117](https://github.com/encode/starlette/pull/3117).
* Remove `@app.exception_handler()` decorator from `Starlette`.
Use `exception_handlers` parameter instead [#3117](https://github.com/encode/starlette/pull/3117).
* Remove `@app.middleware()` decorator from `Starlette`.
Use `middleware` parameter instead [#3117](https://github.com/encode/starlette/pull/3117).
* Remove `iscoroutinefunction_or_partial()` from `starlette.routing` [#3117](https://github.com/encode/starlette/pull/3117).
* Remove `**env_options` parameter from `Jinja2Templates`.
Use a preconfigured `jinja2.Environment` via the `env` parameter instead [#3118](https://github.com/encode/starlette/pull/3118).
* Remove deprecated `TemplateResponse(name, context)` signature from `Jinja2Templates`.
Use `TemplateResponse(request, name, ...)` instead [#3118](https://github.com/encode/starlette/pull/3118).
* Remove deprecated `method` parameter from `FileResponse` [#3147](https://github.com/encode/starlette/pull/3147).
#### Added
* Add state generic to `WebSocket` [#3132](https://github.com/encode/starlette/pull/3132).
#### Fixed
* Include `bytes` unit in `Content-Range` header on 416 responses.
* Handle null bytes in `StaticFiles` path [#3139](https://github.com/encode/starlette/pull/3139).
* Use sort-based merge for `Range` header parsing [#3138](https://github.com/encode/starlette/pull/3138).
* Set `Content-Type` instead of `Content-Range` on multi-range responses [#3142](https://github.com/encode/starlette/pull/3142).
* Use CRLF line endings in multipart byterange boundaries [#3143](https://github.com/encode/starlette/pull/3143).
* Avoid mutating `FileResponse` headers on range requests [#3144](https://github.com/encode/starlette/pull/3144).
* Return explicit origin in CORS response when credentials are allowed [#3137](https://github.com/encode/starlette/pull/3137).
* Enable `autoescape` by default in `Jinja2Templates` [#3148](https://github.com/encode/starlette/pull/3148).
#### Changed
* `jinja2` must now be installed to import `Jinja2Templates`. Previously it would only
fail when instantiating the class [#3118](https://github.com/encode/starlette/pull/3118).
## 0.52.1 (January 18, 2026)
#### Fixed
* Only use `typing_extensions` in older Python versions [#3109](https://github.com/Kludex/starlette/pull/3109).
## 0.52.0 (January 18, 2026)
In this release, `State` can be accessed using dictionary-style syntax for improved type
safety ([#3036](https://github.com/Kludex/starlette/pull/3036)).
```python
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import TypedDict

import httpx

from starlette.applications import Starlette
from starlette.requests import Request


class State(TypedDict):
    http_client: httpx.AsyncClient


@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[State]:
    async with httpx.AsyncClient() as client:
        yield {"http_client": client}


async def homepage(request: Request[State]):
    client = request.state["http_client"]
    # If you run the below line with mypy or pyright, it will reveal the correct type.
    reveal_type(client)  # Revealed type is 'httpx.AsyncClient'
```
See [Accessing State](lifespan.md#accessing-state) for more details.
## 0.51.0 (January 10, 2026)
#### Added
* Add `allow_private_network` in `CORSMiddleware` [#3065](https://github.com/Kludex/starlette/pull/3065).
#### Changed
* Increase warning stacklevel on `DeprecationWarning` for wsgi module [#3082](https://github.com/Kludex/starlette/pull/3082).
## 0.50.0 (November 1, 2025)
#### Removed
* Drop Python 3.9 support [#3061](https://github.com/Kludex/starlette/pull/3061).
## 0.49.3 (November 1, 2025)
This is the last release that supports Python 3.9, which will be dropped in the next minor release.
#### Fixed
* Relax strictness on `Middleware` type [#3059](https://github.com/Kludex/starlette/pull/3059).
## 0.49.2 (November 1, 2025)
#### Fixed
* Ignore `if-modified-since` header if `if-none-match` is present in `StaticFiles` [#3044](https://github.com/Kludex/starlette/pull/3044).
## 0.49.1 (October 28, 2025)
This release fixes a security vulnerability in the parsing logic of the `Range` header in `FileResponse`.
You can view the full security advisory: [GHSA-7f5h-v6xp-fcq8](https://github.com/Kludex/starlette/security/advisories/GHSA-7f5h-v6xp-fcq8)
#### Fixed
* Optimize the HTTP ranges parsing logic [4ea6e22b489ec388d6004cfbca52dd5b147127c5](https://github.com/Kludex/starlette/commit/4ea6e22b489ec388d6004cfbca52dd5b147127c5)
## 0.49.0 (October 28, 2025)
#### Added
* Add `encoding` parameter to `Config` class [#2996](https://github.com/Kludex/starlette/pull/2996).
* Support multiple cookie headers in `Request.cookies` [#3029](https://github.com/Kludex/starlette/pull/3029).
* Use `Literal` type for `WebSocketEndpoint` encoding values [#3027](https://github.com/Kludex/starlette/pull/3027).
#### Changed
* Do not pollute exception context in `Middleware` when using `BaseHTTPMiddleware` [#2976](https://github.com/Kludex/starlette/pull/2976).
## 0.48.0 (September 13, 2025)
#### Added
* Add official Python 3.14 support [#3013](https://github.com/Kludex/starlette/pull/3013).
#### Changed
* Implement [RFC9110](https://www.rfc-editor.org/rfc/rfc9110) http status names [#2939](https://github.com/Kludex/starlette/pull/2939).
## 0.47.3 (August 24, 2025)
#### Fixed
* Use `asyncio.iscoroutinefunction` for Python 3.12 and older [#2984](https://github.com/Kludex/starlette/pull/2984).
## 0.47.2 (July 20, 2025)
#### Fixed
* Make `UploadFile` check for future rollover [#2962](https://github.com/Kludex/starlette/pull/2962).
## 0.47.1 (June 21, 2025)
#### Fixed
* Use `Self` in `TestClient.__enter__` [#2951](https://github.com/Kludex/starlette/pull/2951).
* Allow async exception handlers to type-check [#2949](https://github.com/Kludex/starlette/pull/2949).
## 0.47.0 (May 29, 2025)
#### Added
* Add support for ASGI `pathsend` extension [#2671](https://github.com/Kludex/starlette/pull/2671).
* Add `partitioned` attribute to `Response.set_cookie` [#2501](https://github.com/Kludex/starlette/pull/2501).
#### Changed
* Change `methods` parameter type from `list[str]` to `Collection[str]`
[#2903](https://github.com/Kludex/starlette/pull/2903).
* Replace `import typing` by `from typing import ...` in the whole codebase
[#2867](https://github.com/Kludex/starlette/pull/2867).
#### Fixed
* Mark `ExceptionMiddleware.http_exception` as async to prevent thread creation
[#2922](https://github.com/Kludex/starlette/pull/2922).
## 0.46.2 (April 13, 2025)
#### Fixed
* Prevent reraising of exception from `BaseHTTPMiddleware` [#2911](https://github.com/Kludex/starlette/pull/2911).
* Use correct index on backwards compatible logic in `TemplateResponse` [#2909](https://github.com/Kludex/starlette/pull/2909).
## 0.46.1 (March 8, 2025)
#### Fixed
* Allow relative directory path when `follow_symlinks=True` [#2896](https://github.com/Kludex/starlette/pull/2896).
## 0.46.0 (February 22, 2025)
#### Added
* `GZipMiddleware`: Make sure `Vary` header is always added if a response can be compressed [#2865](https://github.com/Kludex/starlette/pull/2865).
#### Fixed
* Raise exception from background task on `BaseHTTPMiddleware` [#2812](https://github.com/Kludex/starlette/pull/2812).
* `GZipMiddleware`: Don't compress on server sent events [#2871](https://github.com/Kludex/starlette/pull/2871).
#### Changed
* `MultiPartParser`: Rename `max_file_size` to `spool_max_size` [#2780](https://github.com/Kludex/starlette/pull/2780).
#### Deprecated
* Add deprecated warning to `TestClient(timeout=...)` [#2840](https://github.com/Kludex/starlette/pull/2840).
## 0.45.3 (January 24, 2025)
#### Fixed
* Turn directory into string on `lookup_path` on commonpath comparison [#2851](https://github.com/Kludex/starlette/pull/2851).
## 0.45.2 (January 4, 2025)
#### Fixed
* Make `create_memory_object_stream` compatible with old anyio versions once again, and bump anyio minimum version to 3.6.2 [#2833](https://github.com/Kludex/starlette/pull/2833).
## 0.45.1 (December 30, 2024)
#### Fixed
* Close `MemoryObjectReceiveStream` left unclosed upon exception in `BaseHTTPMiddleware` children [#2813](https://github.com/Kludex/starlette/pull/2813).
* Collect errors more reliably from the WebSocket logic on the `TestClient` [#2814](https://github.com/Kludex/starlette/pull/2814).
#### Refactor
* Use a pair of memory object streams instead of two queues on the `TestClient` [#2829](https://github.com/Kludex/starlette/pull/2829).
## 0.45.0 (December 29, 2024)
#### Removed
* Drop Python 3.8 support [#2823](https://github.com/Kludex/starlette/pull/2823).
* Remove `ExceptionMiddleware` import proxy from `starlette.exceptions` module [#2826](https://github.com/Kludex/starlette/pull/2826).
* Remove deprecated `WS_1004_NO_STATUS_RCVD` and `WS_1005_ABNORMAL_CLOSURE` [#2827](https://github.com/Kludex/starlette/pull/2827).
## 0.44.0 (December 28, 2024)
#### Added
* Add `client` parameter to `TestClient` [#2810](https://github.com/Kludex/starlette/pull/2810).
* Add `max_part_size` parameter to `Request.form()` [#2815](https://github.com/Kludex/starlette/pull/2815).
## 0.43.0 (December 25, 2024)
#### Removed
* Remove deprecated `allow_redirects` argument from `TestClient` [#2808](https://github.com/Kludex/starlette/pull/2808).
#### Added
* Make UUID path parameter conversion more flexible [#2806](https://github.com/Kludex/starlette/pull/2806).
## 0.42.0 (December 14, 2024)
#### Added
* Raise `ClientDisconnect` on `StreamingResponse` [#2732](https://github.com/Kludex/starlette/pull/2732).
#### Fixed
* Use ETag from headers when parsing If-Range in FileResponse [#2761](https://github.com/Kludex/starlette/pull/2761).
* Follow directory symlinks in `StaticFiles` when `follow_symlinks=True` [#2711](https://github.com/Kludex/starlette/pull/2711).
* Bump minimum `python-multipart` version to `0.0.18` [0ba8395](https://github.com/Kludex/starlette/commit/0ba83959e609bbd460966f092287df1bbd564cc6).
* Bump minimum `httpx` version to `0.27.0` [#2773](https://github.com/Kludex/starlette/pull/2773).
## 0.41.3 (November 18, 2024)
#### Fixed
* Exclude the query parameters from `scope["raw_path"]` on the `TestClient` [#2716](https://github.com/Kludex/starlette/pull/2716).
* Replace `dict` by `Mapping` on `HTTPException.headers` [#2749](https://github.com/Kludex/starlette/pull/2749).
* Correct middleware argument passing and improve factory pattern [#2752](https://github.com/Kludex/starlette/pull/2752).
## 0.41.2 (October 27, 2024)
#### Fixed
* Revert bump on `python-multipart` on `starlette[full]` extras [#2737](https://github.com/Kludex/starlette/pull/2737).
## 0.41.1 (October 24, 2024)
#### Fixed
* Bump minimum `python-multipart` version to `0.0.13` [#2734](https://github.com/Kludex/starlette/pull/2734).
* Change `python-multipart` import to `python_multipart` [#2733](https://github.com/Kludex/starlette/pull/2733).
## 0.41.0 (October 15, 2024)
#### Added
- Allow raising `HTTPException` before `websocket.accept()` [#2725](https://github.com/Kludex/starlette/pull/2725).
## 0.40.0 (October 15, 2024)
This release fixes a Denial of service (DoS) via `multipart/form-data` requests.
You can view the full security advisory:
[GHSA-f96h-pmfr-66vw](https://github.com/Kludex/starlette/security/advisories/GHSA-f96h-pmfr-66vw)
#### Fixed
- Add `max_part_size` to `MultiPartParser` to limit the size of parts in `multipart/form-data`
requests [fd038f3](https://github.com/Kludex/starlette/commit/fd038f3070c302bff17ef7d173dbb0b007617733).
## 0.39.2 (September 29, 2024)
#### Fixed
- Allow use of `request.url_for` when only "app" scope is available [#2672](https://github.com/Kludex/starlette/pull/2672).
- Fix internal type hints to support `python-multipart==0.0.12` [#2708](https://github.com/Kludex/starlette/pull/2708).
## 0.39.1 (September 25, 2024)
#### Fixed
- Avoid regex re-compilation in `responses.py` and `schemas.py` [#2700](https://github.com/Kludex/starlette/pull/2700).
- Improve performance of `get_route_path` by removing regular expression usage
[#2701](https://github.com/Kludex/starlette/pull/2701).
- Consider `FileResponse.chunk_size` when handling multiple ranges [#2703](https://github.com/Kludex/starlette/pull/2703).
- Use `token_hex` for generating multipart boundary strings [#2702](https://github.com/Kludex/starlette/pull/2702).
## 0.39.0 (September 23, 2024)
#### Added
* Add support for [HTTP Range](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests) to
`FileResponse` [#2697](https://github.com/Kludex/starlette/pull/2697).
## 0.38.6 (September 22, 2024)
#### Fixed
* Close unclosed `MemoryObjectReceiveStream` in `TestClient` [#2693](https://github.com/Kludex/starlette/pull/2693).
## 0.38.5 (September 7, 2024)
#### Fixed
* Schedule `BackgroundTasks` from within `BaseHTTPMiddleware` [#2688](https://github.com/Kludex/starlette/pull/2688).
This behavior was removed in 0.38.3, and is now restored.
## 0.38.4 (September 1, 2024)
#### Fixed
* Ensure accurate `root_path` removal in `get_route_path` function [#2600](https://github.com/Kludex/starlette/pull/2600).
## 0.38.3 (September 1, 2024)
#### Added
* Support for Python 3.13 [#2662](https://github.com/Kludex/starlette/pull/2662).
#### Fixed
* Don't poll for disconnects in `BaseHTTPMiddleware` via `StreamingResponse` [#2620](https://github.com/Kludex/starlette/pull/2620).
## 0.38.2 (July 27, 2024)
#### Fixed
* Don't assume all routines have `__name__` in `routing.get_name()` [#2648](https://github.com/Kludex/starlette/pull/2648).
## 0.38.1 (July 23, 2024)
#### Removed
* Revert "Add support for ASGI pathsend extension" [#2649](https://github.com/Kludex/starlette/pull/2649).
## 0.38.0 (July 20, 2024)
#### Added
* Allow use of `memoryview` in `StreamingResponse` and `Response` [#2576](https://github.com/Kludex/starlette/pull/2576)
and [#2577](https://github.com/Kludex/starlette/pull/2577).
* Send 404 instead of 500 when filename requested is too long on `StaticFiles` [#2583](https://github.com/Kludex/starlette/pull/2583).
#### Changed
* Fail fast on invalid `Jinja2Template` instantiation parameters [#2568](https://github.com/Kludex/starlette/pull/2568).
* Check endpoint handler is async only once [#2536](https://github.com/Kludex/starlette/pull/2536).
#### Fixed
* Add proper synchronization to `WebSocketTestSession` [#2597](https://github.com/Kludex/starlette/pull/2597).
## 0.37.2 (March 5, 2024)
#### Added
* Add `bytes` to `_RequestData` type [#2510](https://github.com/Kludex/starlette/pull/2510).
#### Fixed
* Revert "Turn `scope["client"]` to `None` on `TestClient` (#2377)" [#2525](https://github.com/Kludex/starlette/pull/2525).
* Remove deprecated `app` argument passed to `httpx.Client` on the `TestClient` [#2526](https://github.com/Kludex/starlette/pull/2526).
## 0.37.1 (February 9, 2024)
#### Fixed
* Warn instead of raise for missing env file on `Config` [#2485](https://github.com/Kludex/starlette/pull/2485).
## 0.37.0 (February 5, 2024)
#### Added
* Support the WebSocket Denial Response ASGI extension [#2041](https://github.com/Kludex/starlette/pull/2041).
## 0.36.3 (February 4, 2024)
#### Fixed
* Create `anyio.Event` on async context [#2459](https://github.com/Kludex/starlette/pull/2459).
## 0.36.2 (February 3, 2024)
#### Fixed
* Upgrade `python-multipart` to `0.0.7` [13e5c26](http://github.com/Kludex/starlette/commit/13e5c26a27f4903924624736abd6131b2da80cc5).
* Avoid duplicate charset on `Content-Type` [#2443](https://github.com/Kludex/starlette/pull/2443).
## 0.36.1 (January 23, 2024)
#### Fixed
* Check if "extensions" in scope before checking the extension [#2438](http://github.com/Kludex/starlette/pull/2438).
## 0.36.0 (January 22, 2024)
#### Added
* Add support for ASGI `pathsend` extension [#2435](http://github.com/Kludex/starlette/pull/2435).
* Cancel `WebSocketTestSession` on close [#2427](http://github.com/Kludex/starlette/pull/2427).
* Raise `WebSocketDisconnect` when `WebSocket.send()` encounters an `IOError` [#2425](http://github.com/Kludex/starlette/pull/2425).
* Raise `FileNotFoundError` when the `env_file` parameter on `Config` is not valid [#2422](http://github.com/Kludex/starlette/pull/2422).
## 0.35.1 (January 11, 2024)
#### Fixed
* Stop using the deprecated "method" parameter in `FileResponse` inside of `StaticFiles` [#2406](https://github.com/Kludex/starlette/pull/2406).
* Make `typing-extensions` optional again [#2409](https://github.com/Kludex/starlette/pull/2409).
## 0.35.0 (January 11, 2024)
#### Added
* Add `*args` to `Middleware` and improve its type hints [#2381](https://github.com/Kludex/starlette/pull/2381).
#### Fixed
* Use `Iterable` instead of `Iterator` on `iterate_in_threadpool` [#2362](https://github.com/Kludex/starlette/pull/2362).
#### Changed
* Handle `root_path` to keep compatibility with mounted ASGI applications and WSGI [#2400](https://github.com/Kludex/starlette/pull/2400).
* Turn `scope["client"]` to `None` on `TestClient` [#2377](https://github.com/Kludex/starlette/pull/2377).
## 0.34.0 (December 16, 2023)
### Added
* Use `ParamSpec` for `run_in_threadpool` [#2375](https://github.com/Kludex/starlette/pull/2375).
* Add `UploadFile.__repr__` [#2360](https://github.com/Kludex/starlette/pull/2360).
### Fixed
* Merge URLs properly on `TestClient` [#2376](https://github.com/Kludex/starlette/pull/2376).
* Take weak ETags in consideration on `StaticFiles` [#2334](https://github.com/Kludex/starlette/pull/2334).
### Deprecated
* Deprecate `FileResponse(method=...)` parameter [#2366](https://github.com/Kludex/starlette/pull/2366).
## 0.33.0 (December 1, 2023)
### Added
* Add `middleware` per `Route`/`WebSocketRoute` [#2349](https://github.com/Kludex/starlette/pull/2349).
* Add `middleware` per `Router` [#2351](https://github.com/Kludex/starlette/pull/2351).
### Fixed
* Do not overwrite `"path"` and `"root_path"` scope keys [#2352](https://github.com/Kludex/starlette/pull/2352).
* Set `ensure_ascii=False` on `json.dumps()` for `WebSocket.send_json()` [#2341](https://github.com/Kludex/starlette/pull/2341).
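
The effect of `ensure_ascii=False` can be seen with the standard library alone. This is only an illustration of the `json.dumps()` flag; the payload is made up for the example.

```python
import json

payload = {"greeting": "olá", "city": "São Paulo"}

# Default behaviour: non-ASCII characters are written as \uXXXX escapes.
escaped = json.dumps(payload)

# With ensure_ascii=False the characters are emitted as-is.
readable = json.dumps(payload, ensure_ascii=False)

print(escaped)
print(readable)
```

Both strings decode to the same object; the second is shorter on the wire and easier to read in logs.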
## 0.32.0.post1 (November 5, 2023)
### Fixed
* Revert mkdocs-material from 9.1.17 to 9.4.7 [#2326](https://github.com/Kludex/starlette/pull/2326).
## 0.32.0 (November 4, 2023)
### Added
* Send `reason` on `WebSocketDisconnect` [#2309](https://github.com/Kludex/starlette/pull/2309).
* Add `domain` parameter to `SessionMiddleware` [#2280](https://github.com/Kludex/starlette/pull/2280).
### Changed
* Inherit from `HTMLResponse` instead of `Response` on `_TemplateResponse` [#2274](https://github.com/Kludex/starlette/pull/2274).
* Restore the `Response.render` type annotation to its pre-0.31.0 state [#2264](https://github.com/Kludex/starlette/pull/2264).
## 0.31.1 (August 26, 2023)
### Fixed
* Fix import error when `exceptiongroup` isn't available [#2231](https://github.com/Kludex/starlette/pull/2231).
* Set `url_for` global for custom Jinja environments [#2230](https://github.com/Kludex/starlette/pull/2230).
## 0.31.0 (July 24, 2023)
### Added
* Officially support Python 3.12 [#2214](https://github.com/Kludex/starlette/pull/2214).
* Support AnyIO 4.0 [#2211](https://github.com/Kludex/starlette/pull/2211).
* Strictly type annotate Starlette (strict mode on mypy) [#2180](https://github.com/Kludex/starlette/pull/2180).
### Fixed
* Don't group duplicated headers on a single string when using the `TestClient` [#2219](https://github.com/Kludex/starlette/pull/2219).
## 0.30.0 (July 13, 2023)
### Removed
* Drop Python 3.7 support [#2178](https://github.com/Kludex/starlette/pull/2178).
## 0.29.0 (July 13, 2023)
### Added
* Add `follow_redirects` parameter to `TestClient` [#2207](https://github.com/Kludex/starlette/pull/2207).
* Add `__str__` to `HTTPException` and `WebSocketException` [#2181](https://github.com/Kludex/starlette/pull/2181).
* Warn users when using `lifespan` together with `on_startup`/`on_shutdown` [#2193](https://github.com/Kludex/starlette/pull/2193).
* Collect routes from `Host` to generate the OpenAPI schema [#2183](https://github.com/Kludex/starlette/pull/2183).
* Add `request` argument to `TemplateResponse` [#2191](https://github.com/Kludex/starlette/pull/2191).
### Fixed
* Stop `body_stream` in case `more_body=False` on `BaseHTTPMiddleware` [#2194](https://github.com/Kludex/starlette/pull/2194).
## 0.28.0 (June 7, 2023)
### Changed
* Reuse `Request`'s body buffer for call_next in `BaseHTTPMiddleware` [#1692](https://github.com/Kludex/starlette/pull/1692).
* Move exception handling logic to `Route` [#2026](https://github.com/Kludex/starlette/pull/2026).
### Added
* Add `env` parameter to `Jinja2Templates`, and deprecate `**env_options` [#2159](https://github.com/Kludex/starlette/pull/2159).
* Add clear error message when `httpx` is not installed [#2177](https://github.com/Kludex/starlette/pull/2177).
### Fixed
* Allow "name" argument on `templates url_for()` [#2127](https://github.com/Kludex/starlette/pull/2127).
## 0.27.0 (May 16, 2023)
This release fixes a path traversal vulnerability in `StaticFiles`. You can view the full security advisory:
[GHSA-v5gw-mw7f-84px](https://github.com/Kludex/starlette/security/advisories/GHSA-v5gw-mw7f-84px)
### Added
* Minify JSON websocket data via `send_json` [#2128](https://github.com/Kludex/starlette/pull/2128).
### Fixed
* Replace `commonprefix` by `commonpath` on `StaticFiles` [1797de4](https://github.com/Kludex/starlette/commit/1797de464124b090f10cf570441e8292936d63e3).
* Convert ImportErrors into ModuleNotFoundError [#2135](https://github.com/Kludex/starlette/pull/2135).
* Correct the RuntimeError message content in websockets [#2141](https://github.com/Kludex/starlette/pull/2141).
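
The difference between `commonprefix` and `commonpath` explains why the swap matters for a containment check. The sketch below uses `posixpath` so the result is the same on every platform; the directory and file names are invented for illustration.

```python
import posixpath

directory = "/srv/static"
sibling = "/srv/static_private/secret.txt"  # outside `directory`
inside = "/srv/static/css/site.css"         # inside `directory`

# commonprefix compares character by character, so a sibling directory
# sharing the same leading characters looks like it is contained.
prefix_says_inside = posixpath.commonprefix([sibling, directory]) == directory

# commonpath compares whole path components and rejects the sibling.
path_says_inside = posixpath.commonpath([sibling, directory]) == directory

print(prefix_says_inside, path_says_inside)
print(posixpath.commonpath([inside, directory]) == directory)
```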
## 0.26.1 (March 13, 2023)
### Fixed
* Fix typing of Lifespan to allow subclasses of Starlette [#2077](https://github.com/Kludex/starlette/pull/2077).
## 0.26.0.post1 (March 9, 2023)
### Fixed
* Replace reference from Events to Lifespan on the mkdocs.yml [#2072](https://github.com/Kludex/starlette/pull/2072).
## 0.26.0 (March 9, 2023)
### Added
* Support [lifespan state](lifespan.md) [#2060](https://github.com/Kludex/starlette/pull/2060),
[#2065](https://github.com/Kludex/starlette/pull/2065) and [#2064](https://github.com/Kludex/starlette/pull/2064).
### Changed
* Change `url_for` signature to return a `URL` instance [#1385](https://github.com/Kludex/starlette/pull/1385).
### Fixed
* Allow "name" argument on `url_for()` and `url_path_for()` [#2050](https://github.com/Kludex/starlette/pull/2050).
### Deprecated
* Deprecate `on_startup` and `on_shutdown` events [#2070](https://github.com/Kludex/starlette/pull/2070).
## 0.25.0 (February 14, 2023)
### Fixed
* Limit the number of fields and files when parsing `multipart/form-data` on the `MultiPartParser` [8c74c2c](https://github.com/Kludex/starlette/commit/8c74c2c8dba7030154f8af18e016136bea1938fa) and [#2036](https://github.com/Kludex/starlette/pull/2036).
## 0.24.0 (February 6, 2023)
### Added
* Allow `StaticFiles` to follow symlinks [#1683](https://github.com/Kludex/starlette/pull/1683).
* Allow `Request.form()` as a context manager [#1903](https://github.com/Kludex/starlette/pull/1903).
* Add `size` attribute to `UploadFile` [#1405](https://github.com/Kludex/starlette/pull/1405).
* Add `env_prefix` argument to `Config` [#1990](https://github.com/Kludex/starlette/pull/1990).
* Add template context processors [#1904](https://github.com/Kludex/starlette/pull/1904).
* Support `str` and `datetime` on `expires` parameter on the `Response.set_cookie` method [#1908](https://github.com/Kludex/starlette/pull/1908).
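
For the `expires` entry above, a `datetime` has to end up as an HTTP date string in the cookie header. One way to produce that string with the standard library is sketched below; it illustrates the format only and assumes a timezone-aware UTC value.

```python
from datetime import datetime, timezone
from email.utils import format_datetime

# A timezone-aware UTC datetime (usegmt=True requires a zero UTC offset).
expires = datetime(2023, 2, 6, 12, 0, 0, tzinfo=timezone.utc)

# Render the value in the HTTP date format used by cookie attributes.
header_value = format_datetime(expires, usegmt=True)
print(header_value)  # Mon, 06 Feb 2023 12:00:00 GMT
```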
### Changed
* Lazily build the middleware stack [#2017](https://github.com/Kludex/starlette/pull/2017).
* Make the `file` argument required on `UploadFile` [#1413](https://github.com/Kludex/starlette/pull/1413).
* Use debug extension instead of custom response template extension [#1991](https://github.com/Kludex/starlette/pull/1991).
### Fixed
* Fix url parsing of ipv6 urls on `URL.replace` [#1965](https://github.com/Kludex/starlette/pull/1965).
## 0.23.1 (December 9, 2022)
### Fixed
* Only stop receiving stream on `body_stream` if body is empty on the `BaseHTTPMiddleware` [#1940](https://github.com/Kludex/starlette/pull/1940).
## 0.23.0 (December 5, 2022)
### Added
* Add `headers` parameter to the `TestClient` [#1966](https://github.com/Kludex/starlette/pull/1966).
### Deprecated
* Deprecate `Starlette` and `Router` decorators [#1897](https://github.com/Kludex/starlette/pull/1897).
### Fixed
* Fix bug on `FloatConvertor` regex [#1973](https://github.com/Kludex/starlette/pull/1973).
## 0.22.0 (November 17, 2022)
### Changed
* Bypass `GZipMiddleware` when response includes `Content-Encoding` [#1901](https://github.com/Kludex/starlette/pull/1901).
### Fixed
* Remove unneeded `unquote()` from query parameters on the `TestClient` [#1953](https://github.com/Kludex/starlette/pull/1953).
* Make sure `MutableHeaders._list` is actually a `list` [#1917](https://github.com/Kludex/starlette/pull/1917).
* Import compatibility with the next version of `AnyIO` [#1936](https://github.com/Kludex/starlette/pull/1936).
## 0.21.0 (September 26, 2022)
This release replaces the underlying HTTP client used on the `TestClient` (`requests` :arrow_right: `httpx`), and as those clients [differ _a bit_ on their API](https://www.python-httpx.org/compatibility/), your test suite will likely break. To make the migration smoother, you can use the [`bump-testclient`](https://github.com/Kludex/bump-testclient) tool.
### Changed
* Replace `requests` with `httpx` in `TestClient` [#1376](https://github.com/Kludex/starlette/pull/1376).
### Added
* Add `WebSocketException` and support for WebSocket exception handlers [#1263](https://github.com/Kludex/starlette/pull/1263).
* Add `middleware` parameter to `Mount` class [#1649](https://github.com/Kludex/starlette/pull/1649).
* Officially support Python 3.11 [#1863](https://github.com/Kludex/starlette/pull/1863).
* Implement `__repr__` for route classes [#1864](https://github.com/Kludex/starlette/pull/1864).
### Fixed
* Fix bug on which `BackgroundTasks` were cancelled when using `BaseHTTPMiddleware` and client disconnected [#1715](https://github.com/Kludex/starlette/pull/1715).
## 0.20.4 (June 28, 2022)
### Fixed
* Remove converter from path when generating OpenAPI schema [#1648](https://github.com/Kludex/starlette/pull/1648).
## 0.20.3 (June 10, 2022)
### Fixed
* Revert "Allow `StaticFiles` to follow symlinks" [#1681](https://github.com/Kludex/starlette/pull/1681).
## 0.20.2 (June 7, 2022)
### Fixed
* Fix regression on route paths with colons [#1675](https://github.com/Kludex/starlette/pull/1675).
* Allow `StaticFiles` to follow symlinks [#1337](https://github.com/Kludex/starlette/pull/1377).
## 0.20.1 (May 28, 2022)
### Fixed
* Improve detection of async callables [#1444](https://github.com/Kludex/starlette/pull/1444).
* Send 400 (Bad Request) when `boundary` is missing [#1617](https://github.com/Kludex/starlette/pull/1617).
* Send 400 (Bad Request) when missing "name" field on `Content-Disposition` header [#1643](https://github.com/Kludex/starlette/pull/1643).
* Do not send empty data to `StreamingResponse` on `BaseHTTPMiddleware` [#1609](https://github.com/Kludex/starlette/pull/1609).
* Add `__bool__` dunder for `Secret` [#1625](https://github.com/Kludex/starlette/pull/1625).
## 0.20.0 (May 3, 2022)
### Removed
* Drop Python 3.6 support [#1357](https://github.com/Kludex/starlette/pull/1357) and [#1616](https://github.com/Kludex/starlette/pull/1616).
## 0.19.1 (April 22, 2022)
### Fixed
* Fix inference of `Route.name` when created from methods [#1553](https://github.com/Kludex/starlette/pull/1553).
* Avoid `TypeError` on `websocket.disconnect` when code is `None` [#1574](https://github.com/Kludex/starlette/pull/1574).
### Deprecated
* Deprecate `WS_1004_NO_STATUS_RCVD` and `WS_1005_ABNORMAL_CLOSURE` in favor of `WS_1005_NO_STATUS_RCVD` and `WS_1006_ABNORMAL_CLOSURE`, as the previous constants didn't match the [WebSockets specs](https://www.iana.org/assignments/websocket/websocket.xhtml) [#1580](https://github.com/Kludex/starlette/pull/1580).
## 0.19.0 (March 9, 2022)
### Added
* Error handler will always run, even if the error happens on a background task [#761](https://github.com/Kludex/starlette/pull/761).
* Add `headers` parameter to `HTTPException` [#1435](https://github.com/Kludex/starlette/pull/1435).
* Internal responses with `405` status code insert an `Allow` header, as described by [RFC 7231](https://datatracker.ietf.org/doc/html/rfc7231#section-6.5.5) [#1436](https://github.com/Kludex/starlette/pull/1436).
* The `content` argument in `JSONResponse` is now required [#1431](https://github.com/Kludex/starlette/pull/1431).
* Add custom URL convertor register [#1437](https://github.com/Kludex/starlette/pull/1437).
* Add content disposition type parameter to `FileResponse` [#1266](https://github.com/Kludex/starlette/pull/1266).
* Add next query param with original request URL in requires decorator [#920](https://github.com/Kludex/starlette/pull/920).
* Add `raw_path` to `TestClient` scope [#1445](https://github.com/Kludex/starlette/pull/1445).
* Add union operators to `MutableHeaders` [#1240](https://github.com/Kludex/starlette/pull/1240).
* Display missing route details on debug page [#1363](https://github.com/Kludex/starlette/pull/1363).
* Change `anyio` required version range to `>=3.4.0,<5.0` [#1421](https://github.com/Kludex/starlette/pull/1421) and [#1460](https://github.com/Kludex/starlette/pull/1460).
* Add `typing-extensions>=3.10` requirement - used only on lower versions than Python 3.10 [#1475](https://github.com/Kludex/starlette/pull/1475).
### Fixed
* Prevent `BaseHTTPMiddleware` from hiding errors of `StreamingResponse` and mounted applications [#1459](https://github.com/Kludex/starlette/pull/1459).
* `SessionMiddleware` uses an explicit `path=...`, instead of defaulting to the ASGI 'root_path' [#1512](https://github.com/Kludex/starlette/pull/1512).
* `Request.client` is now compliant with the ASGI specifications [#1462](https://github.com/Kludex/starlette/pull/1462).
* Raise `KeyError` at early stage for missing boundary [#1349](https://github.com/Kludex/starlette/pull/1349).
### Deprecated
* Deprecate WSGIMiddleware in favor of a2wsgi [#1504](https://github.com/Kludex/starlette/pull/1504).
* Deprecate `run_until_first_complete` [#1443](https://github.com/Kludex/starlette/pull/1443).
## 0.18.0 (January 23, 2022)
### Added
* Change default chunk size from 4KB to 64KB on `FileResponse` [#1345](https://github.com/Kludex/starlette/pull/1345).
* Add support for `functools.partial` in `WebSocketRoute` [#1356](https://github.com/Kludex/starlette/pull/1356).
* Add `StaticFiles` packages with directory [#1350](https://github.com/Kludex/starlette/pull/1350).
* Allow environment options in `Jinja2Templates` [#1401](https://github.com/Kludex/starlette/pull/1401).
* Allow HEAD method on `HTTPEndpoint` [#1346](https://github.com/Kludex/starlette/pull/1346).
* Accept additional headers on `websocket.accept` message [#1361](https://github.com/Kludex/starlette/pull/1361) and [#1422](https://github.com/Kludex/starlette/pull/1422).
* Add `reason` to `WebSocket` close ASGI event [#1417](https://github.com/Kludex/starlette/pull/1417).
* Add headers attribute to `UploadFile` [#1382](https://github.com/Kludex/starlette/pull/1382).
* Don't omit `Content-Length` header for `Content-Length: 0` cases [#1395](https://github.com/Kludex/starlette/pull/1395).
* Don't set headers for responses with 1xx, 204 and 304 status code [#1397](https://github.com/Kludex/starlette/pull/1397).
* `SessionMiddleware.max_age` now accepts `None`, so cookie can last as long as the browser session [#1387](https://github.com/Kludex/starlette/pull/1387).
### Fixed
* Tweak the `hashlib.md5()` call used for `FileResponse` ETag generation: the [`usedforsecurity`](https://bugs.python.org/issue9216) flag is set to `False` when the flag is available on the system. This fixes an error raised on systems with [FIPS](https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/FIPS_Mode_-_an_explanation) enabled [#1366](https://github.com/Kludex/starlette/pull/1366) and [#1410](https://github.com/Kludex/starlette/pull/1410).
* Fix `path_params` type on `url_path_for()` method i.e. turn `str` into `Any` [#1341](https://github.com/Kludex/starlette/pull/1341).
* `Host` now ignores `port` on routing [#1322](https://github.com/Kludex/starlette/pull/1322).
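
The `usedforsecurity` fix mentioned above can be sketched with a small fallback wrapper. `md5_hexdigest` is a hypothetical helper written for this note, and the ETag input is a made-up value.

```python
import hashlib


def md5_hexdigest(data: bytes) -> str:
    """Hash `data` with MD5 for a non-security purpose (hypothetical helper)."""
    try:
        # Python 3.9+ accepts the flag, which lets FIPS-enabled builds
        # permit MD5 for non-cryptographic uses such as cache validators.
        digest = hashlib.md5(data, usedforsecurity=False)
    except TypeError:
        # Older interpreters do not know the keyword argument.
        digest = hashlib.md5(data)
    return digest.hexdigest()


etag = md5_hexdigest(b"1700000000.0-1024")
print(etag)
```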
## 0.17.1 (November 17, 2021)
### Fixed
* Fix `IndexError` in authentication `requires` when wrapped function arguments are distributed between `*args` and `**kwargs` [#1335](https://github.com/Kludex/starlette/pull/1335).
## 0.17.0 (November 4, 2021)
### Added
* `Response.delete_cookie` now accepts the same parameters as `Response.set_cookie` [#1228](https://github.com/Kludex/starlette/pull/1228).
* Update the `Jinja2Templates` constructor to allow `PathLike` [#1292](https://github.com/Kludex/starlette/pull/1292).
### Fixed
* Fix BadSignature exception handling in SessionMiddleware [#1264](https://github.com/Kludex/starlette/pull/1264).
* Change `HTTPConnection.__getitem__` return type from `str` to `typing.Any` [#1118](https://github.com/Kludex/starlette/pull/1118).
* Change `ImmutableMultiDict.getlist` return type from `typing.List[str]` to `typing.List[typing.Any]` [#1235](https://github.com/Kludex/starlette/pull/1235).
* Handle `OSError` exceptions on `StaticFiles` [#1220](https://github.com/Kludex/starlette/pull/1220).
* Fix `StaticFiles` 404.html in HTML mode [#1314](https://github.com/Kludex/starlette/pull/1314).
* Prevent anyio.ExceptionGroup in error views under a BaseHTTPMiddleware [#1262](https://github.com/Kludex/starlette/pull/1262).
### Removed
* Remove GraphQL support [#1198](https://github.com/Kludex/starlette/pull/1198).
## 0.16.0 (July 19, 2021)
### Added
* Added [Encode](https://github.com/sponsors/encode) funding option
[#1219](https://github.com/Kludex/starlette/pull/1219)
### Fixed
* `starlette.websockets.WebSocket` instances are now hashable and compare by identity
[#1039](https://github.com/Kludex/starlette/pull/1039)
* A number of fixes related to running task groups in lifespan
[#1213](https://github.com/Kludex/starlette/pull/1213),
[#1227](https://github.com/Kludex/starlette/pull/1227)
### Deprecated/removed
* The method `starlette.templates.Jinja2Templates.get_env` was removed
[#1218](https://github.com/Kludex/starlette/pull/1218)
* The ClassVar `starlette.testclient.TestClient.async_backend` was removed,
the backend is now configured using constructor kwargs
[#1211](https://github.com/Kludex/starlette/pull/1211)
* Passing an Async Generator Function or a Generator Function to `starlette.routing.Router(lifespan=)` is deprecated. You should wrap your lifespan in `@contextlib.asynccontextmanager`.
[#1227](https://github.com/Kludex/starlette/pull/1227)
[#1110](https://github.com/Kludex/starlette/pull/1110)
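
The recommended shape of a lifespan function is an async context manager. The sketch below runs one with plain `asyncio` so it stays self-contained; the `app` object is a stand-in, and the `events` list exists only to show the ordering.

```python
import asyncio
import contextlib

events = []


@contextlib.asynccontextmanager
async def lifespan(app):
    # Code before `yield` runs at startup.
    events.append("startup")
    try:
        yield
    finally:
        # Code after `yield` runs at shutdown.
        events.append("shutdown")


async def main():
    app = object()  # stand-in for the real application instance
    async with lifespan(app):
        events.append("serving")


asyncio.run(main())
print(events)  # ['startup', 'serving', 'shutdown']
```

In an application, the same function would be passed as the `lifespan=` argument instead of being entered by hand.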
## 0.15.0 (June 23, 2021)
This release includes major changes to the low-level asynchronous parts of Starlette. As a result,
**Starlette now depends on [AnyIO](https://anyio.readthedocs.io/en/stable/)** and some minor API
changes have occurred. Another significant change with this release is the
**deprecation of built-in GraphQL support**.
### Added
* Starlette now supports [Trio](https://trio.readthedocs.io/en/stable/) as an async runtime via
AnyIO - [#1157](https://github.com/Kludex/starlette/pull/1157).
* `TestClient.websocket_connect()` now must be used as a context manager.
* Initial support for Python 3.10 - [#1201](https://github.com/Kludex/starlette/pull/1201).
* The compression level used in `GZipMiddleware` is now adjustable -
[#1128](https://github.com/Kludex/starlette/pull/1128).
### Fixed
* Several fixes to `CORSMiddleware`. See [#1111](https://github.com/Kludex/starlette/pull/1111),
[#1112](https://github.com/Kludex/starlette/pull/1112),
[#1113](https://github.com/Kludex/starlette/pull/1113),
[#1199](https://github.com/Kludex/starlette/pull/1199).
* Improved exception messages in the case of duplicated path parameter names -
[#1177](https://github.com/Kludex/starlette/pull/1177).
* `RedirectResponse` now uses `quote` instead of `quote_plus` encoding for the `Location` header
to better match the behaviour in other frameworks such as Django -
[#1164](https://github.com/Kludex/starlette/pull/1164).
* Exception causes are now preserved in more cases -
[#1158](https://github.com/Kludex/starlette/pull/1158).
* Session cookies now use the ASGI root path in the case of mounted applications -
[#1147](https://github.com/Kludex/starlette/pull/1147).
* Fixed a cache invalidation bug when static files were deleted in certain circumstances -
[#1023](https://github.com/Kludex/starlette/pull/1023).
* Improved memory usage of `BaseHTTPMiddleware` when handling large responses -
[#1012](https://github.com/Kludex/starlette/issues/1012) fixed via #1157
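
The `quote` versus `quote_plus` change noted above is easiest to see side by side. This sketch uses the functions' default arguments, which is an assumption for illustration and may differ from the exact `safe` set used for the `Location` header.

```python
from urllib.parse import quote, quote_plus

target = "/my path/file name.txt"

# quote keeps "/" and encodes spaces as %20, which suits URL paths.
as_path = quote(target)

# quote_plus is meant for form/query values: "/" is escaped, spaces become "+".
as_form = quote_plus(target)

print(as_path)  # /my%20path/file%20name.txt
print(as_form)  # %2Fmy+path%2Ffile+name.txt
```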
### Deprecated/removed
* Built-in GraphQL support via the `GraphQLApp` class has been deprecated and will be removed in a
future release. Please see [#619](https://github.com/Kludex/starlette/issues/619). GraphQL is not
supported on Python 3.10.
* The `executor` parameter to `GraphQLApp` was removed. Use `executor_class` instead.
* The `workers` parameter to `WSGIMiddleware` was removed. This hasn't had any effect since
Starlette v0.6.3.
## 0.14.2 (February 2, 2021)
### Fixed
* Fixed `ServerErrorMiddleware` compatibility with Python 3.9.1/3.8.7 when debug mode is enabled -
[#1132](https://github.com/Kludex/starlette/pull/1132).
* Fixed unclosed socket `ResourceWarning`s when using the `TestClient` with WebSocket endpoints -
#1132.
* Improved detection of `async` endpoints wrapped in `functools.partial` on Python 3.8+ -
[#1106](https://github.com/Kludex/starlette/pull/1106).
## 0.14.1 (November 9th, 2020)
### Removed
* `UJSONResponse` was removed (this change was intended to be included in 0.14.0). Please see the
[documentation](https://starlette.dev/responses/#custom-json-serialization) for how to
implement responses using custom JSON serialization -
[#1074](https://github.com/Kludex/starlette/pull/1047).
## 0.14.0 (November 8th, 2020)
### Added
* Starlette now officially supports Python 3.9.
* In `StreamingResponse`, allow custom async iterator such as objects from classes implementing `__aiter__`.
* Allow usage of `functools.partial` async handlers in Python versions 3.6 and 3.7.
* Add 418 I'm A Teapot status code.
### Changed
* Create tasks from handler coroutines before sending them to `asyncio.wait`.
* Use `format_exception` instead of `format_tb` in `ServerErrorMiddleware`'s `debug` responses.
* Be more lenient with handler arguments when using the `requires` decorator.
## 0.13.8
* Revert `Queue(maxsize=1)` fix for `BaseHTTPMiddleware` middleware classes and streaming responses.
* The `StaticFiles` constructor now allows `pathlib.Path` in addition to strings for its `directory` argument.
## 0.13.7
* Fix high memory usage when using `BaseHTTPMiddleware` middleware classes and streaming responses.
## 0.13.6
* Fix 404 errors with `StaticFiles`.
## 0.13.5
* Add support for `Starlette(lifespan=...)` functions.
* More robust path-traversal check in StaticFiles app.
* Fix WSGI PATH_INFO encoding.
* RedirectResponse now accepts optional background parameter.
* Allow path routes to contain regex meta characters.
* Treat ASGI HTTP 'body' as an optional key.
* Don't use thread pooling for writing to in-memory upload files.
## 0.13.0
* Switch to promoting application configuration on init style everywhere.
This means dropping the decorator style in favour of declarative routing
tables and middleware definitions.
## 0.12.12
* Fix `request.url_for()` for the Mount-within-a-Mount case.
## 0.12.11
* Fix `request.url_for()` when an ASGI `root_path` is being used.
## 0.12.1
* Add `URL.include_query_params(**kwargs)`
* Add `URL.replace_query_params(**kwargs)`
* Add `URL.remove_query_params(param_names)`
* `request.state` properly persisting across middleware.
* Added `request.scope` interface.
## 0.12.0
* Switch to ASGI 3.0.
* Fixes to CORS middleware.
* Add `StaticFiles(html=True)` support.
* Fix path quoting in redirect responses.
## 0.11.1
* Add `request.state` interface, for storing arbitrary additional information.
* Support disabling GraphiQL with `GraphQLApp(..., graphiql=False)`.
## 0.11.0
* `DatabaseMiddleware` is now dropped in favour of `databases`
* Templates are no longer configured on the application instance. Use `templates = Jinja2Templates(directory=...)` and `return templates.TemplateResponse('index.html', {"request": request})`
* Schema generation is no longer attached to the application instance. Use `schemas = SchemaGenerator(...)` and `return schemas.OpenAPIResponse(request=request)`
* `LifespanMiddleware` is dropped in favor of router-based lifespan handling.
* Application instances now accept a `routes` argument, `Starlette(routes=[...])`
* Schema generation now includes mounted routes.
## 0.10.6
* Add `Lifespan` routing component.
## 0.10.5
* Ensure `templating` does not strictly require `jinja2` to be installed.
## 0.10.4
* Templates are now configured independently from the application instance. `templates = Jinja2Templates(directory=...)`. Existing API remains in place, but is no longer documented,
and will be deprecated in due course. See the template documentation for more details.
## 0.10.3
* Move to independent `databases` package instead of `DatabaseMiddleware`. Existing API
remains in place, but is no longer documented, and will be deprecated in due course.
## 0.10.2
* Don't drop explicit port numbers on redirects from `HTTPSRedirectMiddleware`.
## 0.10.1
* Add MySQL database support.
* Add host-based routing.
## 0.10.0
* WebSockets now default to sending/receiving JSON over text data frames. Use `.send_json(data, mode="binary")` and `.receive_json(mode="binary")` for binary framing.
* `GraphQLApp` now takes an `executor_class` argument, which should be used in preference to the existing `executor` argument. Resolves an issue with async executors being instantiated before the event loop was set up. The `executor` argument is expected to be deprecated in the next minor or major release.
* Authentication and the `@requires` decorator now support WebSocket endpoints.
* `MultiDict` and `ImmutableMultiDict` classes are available in `starlette.datastructures`.
* `QueryParams` is now instantiated with standard dict-style `*args, **kwargs` arguments.
## 0.9.11
* Session cookies now include browser 'expires', in addition to the existing signed expiry.
* `request.form()` now returns a multi-dict interface.
* The query parameter multi-dict implementation now mirrors `dict` more correctly for the
behavior of `.keys()`, `.values()`, and `.items()` when multiple same-key items occur.
* Use `urlsplit` throughout in favor of `urlparse`.
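
The practical difference between the two parsers is that `urlparse` splits a legacy `;params` segment off the path while `urlsplit` leaves the path intact. The URL below is invented for illustration.

```python
from urllib.parse import urlparse, urlsplit

url = "https://example.org/path;v=1?q=starlette#top"

parsed = urlparse(url)  # 6 fields, including the legacy `params`
split = urlsplit(url)   # 5 fields, no `params`

print(parsed.path, parsed.params)  # /path v=1
print(split.path)                  # /path;v=1
```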
## 0.9.10
* Support `@requires(...)` on class methods.
* Apply URL escaping to form data.
* Support `HEAD` requests automatically.
* Add `await request.is_disconnected()`.
* Pass operationName to GraphQL executor.
## 0.9.9
* Add `TemplateResponse`.
* Add `CommaSeparatedStrings` datatype.
* Add `BackgroundTasks` for multiple tasks.
* Add a common base class for `Request` and `WebSocket`, e.g. to share `session` functionality.
* Expose remote address with `request.client`.
## 0.9.8
* Add `request.database.executemany`.
## 0.9.7
* Ensure that `AuthenticationMiddleware` handles lifespan messages correctly.
## 0.9.6
* Add `AuthenticationMiddleware`, and `@requires()` decorator.
## 0.9.5
* Support either `str` or `Secret` for `SessionMiddleware(secret_key=...)`.
## 0.9.4
* Add `config.environ`.
* Add `datastructures.Secret`.
* Add `datastructures.DatabaseURL`.
## 0.9.3
* Add `config.Config(".env")`
## 0.9.2
* Add optional database support.
* Add `request` to GraphQL context.
* Hide any password component in `URL.__repr__`.
## 0.9.1
* Handle startup/shutdown errors properly.
## 0.9.0
* `TestClient` can now be used as a context manager, instead of `LifespanContext`.
* Lifespan is now handled as middleware. Startup and Shutdown events are
visible throughout the middleware stack.
## 0.8.8
* Better support for third-party API schema generators.
## 0.8.7
* Support chunked requests with TestClient.
* Cleanup asyncio tasks properly with WSGIMiddleware.
* Support using TestClient within endpoints, for service mocking.
## 0.8.6
* Session cookies are now set on the root path.
## 0.8.5
* Support URL convertors.
* Support HTTP 304 cache responses from `StaticFiles`.
* Resolve character escaping issue with form data.
## 0.8.4
* Default to empty body on responses.
## 0.8.3
* Add 'name' argument to `@app.route()`.
* Use 'Host' header for URL reconstruction.
## 0.8.2
### StaticFiles
* StaticFiles no longer reads the file for responses to `HEAD` requests.
## 0.8.1
### Templating
* Add a default templating configuration with Jinja2.
Allows the following:
```python
from starlette.applications import Starlette
from starlette.responses import HTMLResponse

app = Starlette(template_directory="templates")


@app.route('/')
async def homepage(request):
    # `url_for` is available inside the template.
    template = app.get_template('index.html')
    content = template.render(request=request)
    return HTMLResponse(content)
```
## 0.8.0
### Exceptions
* Add support for `@app.exception_handler(404)`.
* Ensure handled exceptions are not seen as errors by the middleware stack.
### SessionMiddleware
* Add `max_age`, and use timestamp-signed cookies. Defaults to two weeks.
### Cookies
* Ensure cookies are strictly HTTP correct.
### StaticFiles
* Check directory exists on instantiation.
## 0.7.4
### Concurrency
* Add `starlette.concurrency.run_in_threadpool`. Now handles `contextvar` support.
## 0.7.3
### Routing
* Add `name=` support to `app.mount()`. This allows eg: `app.mount('/static', StaticFiles(directory='static'), name='static')`.
## 0.7.2
### Middleware
* Add support for `@app.middleware("http")` decorator.
### Routing
* Add "endpoint" to ASGI scope.
## 0.7.1
### Debug tracebacks
* Improve debug traceback information & styling.
### URL routing
* Support mounted URL lookups with "path=", eg. `url_for('static', path=...)`.
* Support nested URL lookups, eg. `url_for('admin:user', username=...)`.
* Add redirect slashes support.
* Add www redirect support.
### Background tasks
* Add background task support to `FileResponse` and `StreamingResponse`.
## 0.7.0
### API Schema support
* Add `app.schema_generator = SchemaGenerator(...)`.
* Add `app.schema` property.
* Add `OpenAPIResponse(...)`.
### GraphQL routing
* Drop `app.add_graphql_route("/", ...)` in favor of more consistent `app.add_route("/", GraphQLApp(...))`.
## 0.6.3
### Routing API
* Support routing to methods.
* Ensure `url_path_for` works with Mount('/{some_path_params}').
* Fix Router(default=) argument.
* Support repeated paths, like: `@app.route("/", methods=["GET"])`, `@app.route("/", methods=["POST"])`
* Use the default ThreadPoolExecutor for all sync endpoints.
## 0.6.2
### SessionMiddleware
Added support for `request.session`, with `SessionMiddleware`.
## 0.6.1
### BaseHTTPMiddleware
Added support for `BaseHTTPMiddleware`, which provides a standard
request/response interface over a regular ASGI middleware.
This means you can write ASGI middleware while still working at
a request/response level, rather than handling ASGI messages directly.
```python
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware


class CustomMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        response = await call_next(request)
        response.headers['Custom-Header'] = 'Example'
        return response


app = Starlette()
app.add_middleware(CustomMiddleware)
```
## 0.6.0
### request.path_params
The biggest change in 0.6 is that endpoint signatures are no longer:
```python
async def func(request: Request, **kwargs) -> Response
```
Instead we just use:
```python
async def func(request: Request) -> Response
```
The path parameters are available on the request as `request.path_params`.
This is different to most Python web frameworks, but I think it actually ends up
being much more nicely consistent all the way through.
### request.url_for()
Request and WebSocketSession now support URL reversing with `request.url_for(name, **path_params)`.
This method returns a fully qualified `URL` instance.
The URL instance is a string-like object.
### app.url_path_for()
Applications now support URL path reversing with `app.url_path_for(name, **path_params)`.
This method returns a `URL` instance with the path and scheme set.
The URL instance is a string-like object, and will return only the path if coerced to a string.
### app.routes
Applications now support a `.routes` parameter, which returns a list of `[Route|WebSocketRoute|Mount]`.
### Route, WebSocketRoute, Mount
The low level components to `Router` now match the `@app.route()`, `@app.websocket_route()`, and `app.mount()` signatures.
================================================
FILE: docs/requests.md
================================================
Starlette includes a `Request` class that gives you a nicer interface onto
the incoming request, rather than accessing the ASGI scope and receive channel directly.
### Request
Signature: `Request(scope, receive=None)`
```python
from starlette.requests import Request
from starlette.responses import Response


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    request = Request(scope, receive)
    content = '%s %s' % (request.method, request.url.path)
    response = Response(content, media_type='text/plain')
    await response(scope, receive, send)
```
Requests present a mapping interface, so you can use them in the same
way as a `scope`.
For instance: `request['path']` will return the ASGI path.
If you don't need to access the request body you can instantiate a request
without providing an argument to `receive`.
#### Method
The request method is accessed as `request.method`.
#### URL
The request URL is accessed as `request.url`.
The property is a string-like object that exposes all the
components that can be parsed out of the URL.
For example: `request.url.path`, `request.url.port`, `request.url.scheme`.
#### Headers
Headers are exposed as an immutable, case-insensitive, multi-dict.
For example: `request.headers['content-type']`
#### Query Parameters
Query parameters are exposed as an immutable multi-dict.
For example: `request.query_params['search']`
#### Path Parameters
Router path parameters are exposed as a dictionary interface.
For example: `request.path_params['username']`
#### Client Address
The client's remote address is exposed as a named two-tuple `request.client` (or `None`).
The hostname or IP address: `request.client.host`
The port number from which the client is connecting: `request.client.port`
#### Cookies
Cookies are exposed as a regular dictionary interface.
For example: `request.cookies.get('mycookie')`
Cookies are ignored in case of an invalid cookie. (RFC2109)
#### Body
There are a few different interfaces for returning the body of the request:
The request body as bytes: `await request.body()`
The request body, parsed as form data or multipart: `async with request.form() as form:`
The request body, parsed as JSON: `await request.json()`
You can also access the request body as a stream, using the `async for` syntax:
```python
from starlette.requests import Request
from starlette.responses import Response


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    request = Request(scope, receive)
    body = b''
    async for chunk in request.stream():
        body += chunk
    response = Response(body, media_type='text/plain')
    await response(scope, receive, send)
```
If you access `.stream()` then the byte chunks are provided without storing
the entire body to memory. Any subsequent calls to `.body()`, `.form()`, or `.json()`
will raise an error.
In some cases such as long-polling, or streaming responses you might need to
determine if the client has dropped the connection. You can determine this
state with `disconnected = await request.is_disconnected()`.
#### Request Files
Request files are normally sent as multipart form data (`multipart/form-data`).
Signature: `request.form(max_files=1000, max_fields=1000, max_part_size=1024*1024)`
You can configure the maximum number of fields or files with the parameters `max_fields` and `max_files`, and the maximum part size with `max_part_size`:
```python
async with request.form(max_files=1000, max_fields=1000, max_part_size=1024*1024):
    ...
```
!!! info
    These limits are for security reasons. Allowing an unlimited number of fields or files could lead to a denial of service attack by consuming a lot of CPU and memory parsing too many empty fields.
When you call `async with request.form() as form` you receive a `starlette.datastructures.FormData` which is an immutable
multidict, containing both file uploads and text input. File upload items are represented as instances of `starlette.datastructures.UploadFile`.
`UploadFile` has the following attributes:
* `filename`: A `str` with the original file name that was uploaded (e.g. `myimage.jpg`), or `None` if it's not available.
* `content_type`: An `str` with the content type (MIME type / media type) or `None` if it's not available (e.g. `image/jpeg`).
* `file`: A `SpooledTemporaryFile` (a file-like object). This is the actual Python file that you can pass directly to other functions or libraries that expect a "file-like" object.
* `headers`: A `Headers` object. Often this will only be the `Content-Type` header, but if additional headers were included in the multipart field they will be included here. Note that these headers have no relationship with the headers in `Request.headers`.
* `size`: An `int` with the uploaded file's size in bytes. This value is calculated from the request's contents, making it a better choice for finding the uploaded file's size than the `Content-Length` header. `None` if not set.
`UploadFile` has the following `async` methods. They all call the corresponding file methods underneath (using the internal `SpooledTemporaryFile`).
* `async write(data)`: Writes `data` (`bytes`) to the file.
* `async read(size)`: Reads `size` (`int`) bytes of the file.
* `async seek(offset)`: Goes to the byte position `offset` (`int`) in the file.
    * E.g., `await myfile.seek(0)` would go to the start of the file.
* `async close()`: Closes the file.
As all these methods are `async` methods, you need to "await" them.
For example, you can get the file name and the contents with:
```python
async with request.form() as form:
    filename = form["upload_file"].filename
    contents = await form["upload_file"].read()
```
!!! info
    As settled in [RFC-7578: 4.2](https://www.ietf.org/rfc/rfc7578.txt), a form-data content part that contains a file
    is assumed to have `name` and `filename` fields in the `Content-Disposition` header: `Content-Disposition: form-data;
    name="user"; filename="somefile"`. Though the `filename` field is optional according to RFC-7578, it helps
    Starlette to differentiate which data should be treated as a file. If the `filename` field was supplied, an `UploadFile`
    object will be created to access the underlying file, otherwise the form-data part will be parsed and available as a raw
    string.
#### Application
The originating Starlette application can be accessed via `request.app`.
#### Other state
If you want to store additional information on the request you can do so
using `request.state`.
For example:
`request.state.time_started = time.time()`
================================================
FILE: docs/responses.md
================================================
Starlette includes a few response classes that handle sending back the
appropriate ASGI messages on the `send` channel.
### Response
Signature: `Response(content, status_code=200, headers=None, media_type=None)`
* `content` - A string or bytestring.
* `status_code` - An integer HTTP status code.
* `headers` - A dictionary of strings.
* `media_type` - A string giving the media type. eg. "text/html"
Starlette will automatically include a Content-Length header. It will also
include a Content-Type header, based on the media_type and appending a charset
for text types, unless a charset has already been specified in the `media_type`.
Once you've instantiated a response, you can send it by calling it as an
ASGI application instance.
```python
from starlette.responses import Response


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    response = Response('Hello, world!', media_type='text/plain')
    await response(scope, receive, send)
```
#### Set Cookie
Starlette provides a `set_cookie` method to allow you to set cookies on the response object.
Signature: `Response.set_cookie(key, value, max_age=None, expires=None, path="/", domain=None, secure=False, httponly=False, samesite="lax", partitioned=False)`
* `key` - A string that will be the cookie's key.
* `value` - A string that will be the cookie's value.
* `max_age` - An integer that defines the lifetime of the cookie in seconds. A negative integer or a value of `0` will discard the cookie immediately. `Optional`
* `expires` - Either an integer that defines the number of seconds until the cookie expires, or a datetime. `Optional`
* `path` - A string that specifies the subset of routes to which the cookie will apply. `Optional`
* `domain` - A string that specifies the domain for which the cookie is valid. `Optional`
* `secure` - A bool indicating that the cookie will only be sent to the server if request is made using SSL and the HTTPS protocol. `Optional`
* `httponly` - A bool indicating that the cookie cannot be accessed via JavaScript through `Document.cookie` property, the `XMLHttpRequest` or `Request` APIs. `Optional`
* `samesite` - A string that specifies the samesite strategy for the cookie. Valid values are `'lax'`, `'strict'` and `'none'`. Defaults to `'lax'`. `Optional`
* `partitioned` - A bool that indicates to user agents that these cross-site cookies should only be available in the same top-level context that the cookie was first set in. Only available for Python 3.14+, otherwise an error will be raised. `Optional`
#### Delete Cookie
Conversely, Starlette also provides a `delete_cookie` method to manually expire a set cookie.
Signature: `Response.delete_cookie(key, path='/', domain=None)`
### HTMLResponse
Takes some text or bytes and returns an HTML response.
```python
from starlette.responses import HTMLResponse


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    response = HTMLResponse('<html><body><h1>Hello, world!</h1></body></html>')
    await response(scope, receive, send)
```
### PlainTextResponse
Takes some text or bytes and returns a plain text response.
```python
from starlette.responses import PlainTextResponse


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    response = PlainTextResponse('Hello, world!')
    await response(scope, receive, send)
```
### JSONResponse
Takes some data and returns an `application/json` encoded response.
```python
from starlette.responses import JSONResponse


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    response = JSONResponse({'hello': 'world'})
    await response(scope, receive, send)
```
#### Custom JSON serialization
If you need fine-grained control over JSON serialization, you can subclass
`JSONResponse` and override the `render` method.
For example, if you wanted to use a third-party JSON library such as
[orjson](https://pypi.org/project/orjson/):
```python
from typing import Any

import orjson
from starlette.responses import JSONResponse


class OrjsonResponse(JSONResponse):
    def render(self, content: Any) -> bytes:
        return orjson.dumps(content)
```
In general you *probably* want to stick with `JSONResponse` by default unless
you are micro-optimising a particular endpoint or need to serialize non-standard
object types.
### RedirectResponse
Returns an HTTP redirect. Uses a 307 status code by default.
```python
from starlette.responses import PlainTextResponse, RedirectResponse


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    if scope['path'] != '/':
        response = RedirectResponse(url='/')
    else:
        response = PlainTextResponse('Hello, world!')
    await response(scope, receive, send)
```
### StreamingResponse
Takes an async generator or a normal generator/iterator and streams the response body.
```python
from starlette.responses import StreamingResponse
import asyncio


async def slow_numbers(minimum, maximum):
    yield '<html><body><ul>'
    for number in range(minimum, maximum + 1):
        yield '<li>%d</li>' % number
        await asyncio.sleep(0.5)
    yield '</ul></body></html>'


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    generator = slow_numbers(1, 10)
    response = StreamingResponse(generator, media_type='text/html')
    await response(scope, receive, send)
```
Keep in mind that file-like objects (like those created by `open()`) are normal iterators, so you can pass them directly to a `StreamingResponse`.
### FileResponse
Asynchronously streams a file as the response.
Takes a different set of arguments to instantiate than the other response types:
* `path` - The filepath to the file to stream.
* `headers` - Any custom headers to include, as a dictionary.
* `media_type` - A string giving the media type. If unset, the filename or path will be used to infer a media type.
* `filename` - If set, this will be included in the response `Content-Disposition`.
* `content_disposition_type` - will be included in the response `Content-Disposition`. Can be set to "attachment" (default) or "inline".
File responses will include appropriate `Content-Length`, `Last-Modified` and `ETag` headers.
```python
from starlette.responses import FileResponse


async def app(scope, receive, send):
    assert scope['type'] == 'http'
    response = FileResponse('statics/favicon.ico')
    await response(scope, receive, send)
```
File responses also support [HTTP range requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests).
The `Accept-Ranges: bytes` header will be included in the response if the file exists. For now, only the `bytes`
range unit is supported.
If the request includes a `Range` header, and the file exists, the response will be a `206 Partial Content` response
with the requested range of bytes. If the range is invalid, the response will be a `416 Range Not Satisfiable` response.
## Third party responses
#### [EventSourceResponse](https://github.com/sysid/sse-starlette)
A response class that implements [Server-Sent Events](https://html.spec.whatwg.org/multipage/server-sent-events.html). It enables event streaming from the server to the client without the complexity of websockets.
================================================
FILE: docs/routing.md
================================================
## HTTP Routing
Starlette has a simple but capable request routing system. A routing table
is defined as a list of routes, and passed when instantiating the application.
```python
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.routing import Route


async def homepage(request):
    return PlainTextResponse("Homepage")

async def about(request):
    return PlainTextResponse("About")


routes = [
    Route("/", endpoint=homepage),
    Route("/about", endpoint=about),
]

app = Starlette(routes=routes)
```
The `endpoint` argument can be one of:
* A regular function or async function, which accepts a single `request`
argument and which should return a response.
* A class that implements the ASGI interface, such as Starlette's [HTTPEndpoint](endpoints.md#httpendpoint).
## Path Parameters
Paths can use URI templating style to capture path components.
```python
Route('/users/{username}', user)
```
By default this will capture characters up to the end of the path or the next `/`.
You can use convertors to modify what is captured. The available convertors are:
* `str` returns a string, and is the default.
* `int` returns a Python integer.
* `float` returns a Python float.
* `uuid` returns a Python `uuid.UUID` instance.
* `path` returns the rest of the path, including any additional `/` characters.
Convertors are used by prefixing them with a colon, like so:
```python
Route('/users/{user_id:int}', user)
Route('/floating-point/{number:float}', floating_point)
Route('/uploaded/{rest_of_path:path}', uploaded)
```
If you need a convertor that is not defined, you can create your own.
The example below shows how to create a `datetime` convertor, and how to register it:
```python
from datetime import datetime

from starlette.convertors import Convertor, register_url_convertor


class DateTimeConvertor(Convertor):
    regex = "[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}(.[0-9]+)?"

    def convert(self, value: str) -> datetime:
        return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S")

    def to_string(self, value: datetime) -> str:
        return value.strftime("%Y-%m-%dT%H:%M:%S")

register_url_convertor("datetime", DateTimeConvertor())
```
After registering it, you'll be able to use it as:
```python
Route('/history/{date:datetime}', history)
```
Path parameters are made available in the request, as the `request.path_params`
dictionary.
```python
async def user(request):
    user_id = request.path_params['user_id']
    ...
```
## Handling HTTP methods
Routes can also specify which HTTP methods are handled by an endpoint:
```python
Route('/users/{user_id:int}', user, methods=["GET", "POST"])
```
By default function endpoints will only accept `GET` requests, unless specified.
## Submounting routes
In large applications you might find that you want to break out parts of the
routing table, based on a common path prefix.
```python
routes = [
    Route('/', homepage),
    Mount('/users', routes=[
        Route('/', users, methods=['GET', 'POST']),
        Route('/{username}', user),
    ])
]
```
This style allows you to define different subsets of the routing table in
different parts of your project.
```python
from myproject import users, auth

routes = [
    Route('/', homepage),
    Mount('/users', routes=users.routes),
    Mount('/auth', routes=auth.routes),
]
```
You can also use mounting to include sub-applications within your Starlette
application. For example...
```python
# This is a standalone static files server:
app = StaticFiles(directory="static")

# This is a static files server mounted within a Starlette application,
# underneath the "/static" path.
routes = [
    ...
    Mount("/static", app=StaticFiles(directory="static"), name="static")
]

app = Starlette(routes=routes)
```
## Reverse URL lookups
You'll often want to be able to generate the URL for a particular route,
such as in cases where you need to return a redirect response.
* Signature: `url_for(name, **path_params) -> URL`
```python
routes = [
    Route("/", homepage, name="homepage")
]

# We can use the following to return a URL...
url = request.url_for("homepage")
```
URL lookups can include path parameters...
```python
routes = [
    Route("/users/{username}", user, name="user_detail")
]

# We can use the following to return a URL...
url = request.url_for("user_detail", username=...)
```
If a `Mount` includes a `name`, then submounts should use a `{prefix}:{name}`
style for reverse URL lookups.
```python
routes = [
    Mount("/users", name="users", routes=[
        Route("/", user, name="user_list"),
        Route("/{username}", user, name="user_detail")
    ])
]

# We can use the following to return URLs...
url = request.url_for("users:user_list")
url = request.url_for("users:user_detail", username=...)
```
Mounted applications may include a `path=...` parameter.
```python
routes = [
    ...
    Mount("/static", app=StaticFiles(directory="static"), name="static")
]

# We can use the following to return URLs...
url = request.url_for("static", path="/css/base.css")
```
For cases where there is no `request` instance, you can make reverse lookups
against the application, although these will only return the URL path.
```python
url = app.url_path_for("user_detail", username=...)
```
## Host-based routing
If you want to use different routes for the same path based on the `Host` header, you can use `Host` routes.
Note that the port is removed from the `Host` header when matching.
For example, `Host(host='example.org:3600', ...)` will match
regardless of whether the `Host` header contains a port, or which port it contains
(`example.org:5600`, `example.org`).
You can therefore include the port in the route when you need it for use in `url_for`.
There are several ways to connect host-based routes to your application:
```python
site = Router() # Use eg. `@site.route()` to configure this.
api = Router() # Use eg. `@api.route()` to configure this.
news = Router() # Use eg. `@news.route()` to configure this.
routes = [
    Host('api.example.org', api, name="site_api")
]
app = Starlette(routes=routes)
app.host('www.example.org', site, name="main_site")
news_host = Host('news.example.org', news)
app.router.routes.append(news_host)
```
URL lookups can include host parameters just like path parameters:
```python
routes = [
    Host("{subdomain}.example.org", name="sub", app=Router(routes=[
        Mount("/users", name="users", routes=[
            Route("/", user, name="user_list"),
            Route("/{username}", user, name="user_detail")
        ])
    ]))
]
...
url = request.url_for("sub:users:user_detail", username=..., subdomain=...)
url = request.url_for("sub:users:user_list", subdomain=...)
```
## Route priority
Incoming paths are matched against each `Route` in order.
In cases where more than one route could match an incoming path, you should
take care to ensure that more specific routes are listed before general cases.
For example:
```python
# Don't do this: `/users/me` will never match incoming requests.
routes = [
    Route('/users/{username}', user),
    Route('/users/me', current_user),
]

# Do this: `/users/me` is tested first.
routes = [
    Route('/users/me', current_user),
    Route('/users/{username}', user),
]
```
## Working with Router instances
If you're working at a low level you might want to use a plain `Router`
instance, rather than creating a `Starlette` application. This gives you
a lightweight ASGI application that just provides the application routing,
without wrapping it up in any middleware.
```python
app = Router(routes=[
    Route('/', homepage),
    Mount('/users', routes=[
        Route('/', users, methods=['GET', 'POST']),
        Route('/{username}', user),
    ])
])
```
## WebSocket Routing
When working with WebSocket endpoints, you should use `WebSocketRoute`
instead of the usual `Route`.
Path parameters, and reverse URL lookups for `WebSocketRoute` work the same
as HTTP `Route`, which can be found in the HTTP [Route](#http-routing) section above.
```python
from starlette.applications import Starlette
from starlette.routing import WebSocketRoute


async def websocket_index(websocket):
    await websocket.accept()
    await websocket.send_text("Hello, websocket!")
    await websocket.close()


async def websocket_user(websocket):
    name = websocket.path_params["name"]
    await websocket.accept()
    await websocket.send_text(f"Hello, {name}")
    await websocket.close()


routes = [
    WebSocketRoute("/", endpoint=websocket_index),
    WebSocketRoute("/{name}", endpoint=websocket_user),
]

app = Starlette(routes=routes)
```
The `endpoint` argument can be one of:
* An async function, which accepts a single `websocket` argument.
* A class that implements the ASGI interface, such as Starlette's [WebSocketEndpoint](endpoints.md#websocketendpoint).
================================================
FILE: docs/schemas.md
================================================
Starlette supports generating API schemas, such as the widely used [OpenAPI
specification][openapi]. (Formerly known as "Swagger".)
Schema generation works by inspecting the routes on the application through
`app.routes`, and using the docstrings or other attributes on the endpoints
in order to determine a complete API schema.
Starlette is not tied to any particular schema generation or validation tooling,
but includes a simple implementation that generates OpenAPI schemas based on
the docstrings.
```python
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.schemas import SchemaGenerator


schemas = SchemaGenerator(
    {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}}
)

def list_users(request):
    """
    responses:
      200:
        description: A list of users.
        examples:
          [{"username": "tom"}, {"username": "lucy"}]
    """
    raise NotImplementedError()


def create_user(request):
    """
    responses:
      200:
        description: A user.
        examples:
          {"username": "tom"}
    """
    raise NotImplementedError()


def openapi_schema(request):
    return schemas.OpenAPIResponse(request=request)


routes = [
    Route("/users", endpoint=list_users, methods=["GET"]),
    Route("/users", endpoint=create_user, methods=["POST"]),
    Route("/schema", endpoint=openapi_schema, include_in_schema=False)
]

app = Starlette(routes=routes)
```
We can now access an OpenAPI schema at the "/schema" endpoint.
You can generate the API Schema directly with `.get_schema(routes)`:
```python
schema = schemas.get_schema(routes=app.routes)
assert schema == {
    "openapi": "3.0.0",
    "info": {"title": "Example API", "version": "1.0"},
    "paths": {
        "/users": {
            "get": {
                "responses": {
                    200: {
                        "description": "A list of users.",
                        "examples": [{"username": "tom"}, {"username": "lucy"}],
                    }
                }
            },
            "post": {
                "responses": {
                    200: {"description": "A user.", "examples": {"username": "tom"}}
                }
            },
        },
    },
}
```
You might also want to be able to print out the API schema, so that you can
use tooling such as generating API documentation.
```python
if __name__ == '__main__':
    assert sys.argv[-1] in ("run", "schema"), "Usage: example.py [run|schema]"

    if sys.argv[-1] == "run":
        uvicorn.run("example:app", host='0.0.0.0', port=8000)
    elif sys.argv[-1] == "schema":
        schema = schemas.get_schema(routes=app.routes)
        print(yaml.dump(schema, default_flow_style=False))
```
### Third party packages
#### [starlette-apispec][starlette-apispec]
Easy APISpec integration for Starlette, which supports some object serialization libraries.
[openapi]: https://github.com/OAI/OpenAPI-Specification
[starlette-apispec]: https://github.com/Woile/starlette-apispec
================================================
FILE: docs/server-push.md
================================================
Starlette includes support for HTTP/2 and HTTP/3 server push, making it
possible to push resources to the client to speed up page load times.
### `Request.send_push_promise`
Used to initiate a server push for a resource. If server push is not available
this method does nothing.
Signature: `send_push_promise(path)`
* `path` - A string denoting the path of the resource.
```python
from starlette.applications import Starlette
from starlette.responses import HTMLResponse
from starlette.routing import Route, Mount
from starlette.staticfiles import StaticFiles


async def homepage(request):
    """
    Homepage which uses server push to deliver the stylesheet.
    """
    await request.send_push_promise("/static/style.css")
    return HTMLResponse(
        '<html><head><link rel="stylesheet" href="/static/style.css"/></head></html>'
    )

routes = [
    Route("/", endpoint=homepage),
    Mount("/static", StaticFiles(directory="static"), name="static")
]

app = Starlette(routes=routes)
```
================================================
FILE: docs/staticfiles.md
================================================
Starlette also includes a `StaticFiles` class for serving files in a given directory:
### StaticFiles
Signature: `StaticFiles(directory=None, packages=None, html=False, check_dir=True, follow_symlink=False)`
* `directory` - A string or [os.PathLike][pathlike] denoting a directory path.
* `packages` - A list of strings or list of tuples of strings of python packages.
* `html` - Run in HTML mode. Automatically loads `index.html` for directories if such a file exists.
* `check_dir` - Ensure that the directory exists upon instantiation. Defaults to `True`.
* `follow_symlink` - A boolean indicating if symbolic links for files and directories should be followed. Defaults to `False`.
You can combine this ASGI application with Starlette's routing to provide
comprehensive static file serving.
```python
from starlette.applications import Starlette
from starlette.routing import Mount
from starlette.staticfiles import StaticFiles
routes = [
    ...
    Mount('/static', app=StaticFiles(directory='static'), name="static"),
]
app = Starlette(routes=routes)
```
Static files will respond with "404 Not found" or "405 Method not allowed"
responses for requests which do not match. In HTML mode, if a `404.html` file
exists it will be shown as the 404 response.
The `packages` option can be used to include "static" directories contained within
a python package. The Python "bootstrap4" package is an example of this.
```python
from starlette.applications import Starlette
from starlette.routing import Mount
from starlette.staticfiles import StaticFiles
routes=[
    ...
    Mount('/static', app=StaticFiles(directory='static', packages=['bootstrap4']), name="static"),
]
app = Starlette(routes=routes)
```
By default `StaticFiles` will look for a `statics` directory in each package.
You can change the default directory by specifying a tuple of strings.
```python
routes=[
    ...
    Mount('/static', app=StaticFiles(packages=[('bootstrap4', 'static')]), name="static"),
]
```
You may prefer to include static files directly inside the "static" directory
rather than using Python packaging to include static files, but it can be useful
for bundling up reusable components.
[pathlike]: https://docs.python.org/3/library/os.html#os.PathLike
================================================
FILE: docs/templates.md
================================================
Starlette is not _strictly_ coupled to any particular templating engine, but
Jinja2 provides an excellent choice.
??? abstract "API Reference"
    ::: starlette.templating.Jinja2Templates
        options:
            parameter_headings: false
            show_root_heading: true
            heading_level: 3
            filters:
                - "__init__"
Starlette provides a simple way to get `jinja2` configured. This is probably
what you want to use by default.
```python
from starlette.applications import Starlette
from starlette.routing import Route, Mount
from starlette.templating import Jinja2Templates
from starlette.staticfiles import StaticFiles
templates = Jinja2Templates(directory='templates')
async def homepage(request):
    return templates.TemplateResponse(request, 'index.html')

routes = [
    Route('/', endpoint=homepage),
    Mount('/static', StaticFiles(directory='static'), name='static')
]
app = Starlette(debug=True, routes=routes)
```
Note that the incoming `request` instance must be included as part of the
template context.
The Jinja2 template context will automatically include a `url_for` function,
so we can correctly hyperlink to other pages within the application.
For example, we can link to static files from within our HTML templates:
```html
<link href="{{ url_for('static', path='/css/bootstrap.min.css') }}" rel="stylesheet" />
```
If you want to use [custom filters][jinja2], you will need to update the `env`
property of `Jinja2Templates`:
```python
from commonmark import commonmark
from starlette.templating import Jinja2Templates
def marked_filter(text):
    return commonmark(text)
templates = Jinja2Templates(directory='templates')
templates.env.filters['marked'] = marked_filter
```
## Using custom jinja2.Environment instance
Starlette also accepts a preconfigured [`jinja2.Environment`](https://jinja.palletsprojects.com/en/3.0.x/api/#api) instance.
```python
import jinja2
from starlette.templating import Jinja2Templates
env = jinja2.Environment(...)
templates = Jinja2Templates(env=env)
```
## Autoescape
When using the `directory` argument, Starlette enables autoescape by default for
`.html`, `.htm`, and `.xml` templates using [`jinja2.select_autoescape()`](https://jinja.palletsprojects.com/en/stable/api/#jinja2.select_autoescape).
This protects against Cross-Site Scripting (XSS) vulnerabilities by escaping
user-provided content before rendering it in the template. For example, if a user
submits `<script>alert('XSS')</script>` as their name, it will be rendered as
`&lt;script&gt;alert('XSS')&lt;/script&gt;` instead of being executed as JavaScript.
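The effect of escaping can be illustrated with the standard library alone. This is only a sketch of the general transformation: Jinja2 performs its own escaping (and also escapes quote characters), so `html.escape` is a stand-in here, not what Starlette or Jinja2 call internally.

```python
from html import escape

# Untrusted input, as a user might submit it in a form field
user_input = "<script>alert('XSS')</script>"

# quote=False escapes only &, < and >, leaving quote characters untouched
escaped = escape(user_input, quote=False)
print(escaped)
```

Because the angle brackets are replaced with entities, the browser displays the text instead of treating it as a script tag.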
## Context processors
A context processor is a function that returns a dictionary to be merged into a template context.
Each function takes a single argument, `request`, and must return a dictionary to add to the context.
A common use case for context processors is to extend the template context with shared variables.
```python
import typing
from starlette.requests import Request
def app_context(request: Request) -> typing.Dict[str, typing.Any]:
    return {'app': request.app}
```
### Registering context processors
Pass context processors to the `context_processors` argument of the `Jinja2Templates` class.
```python
import typing
from starlette.requests import Request
from starlette.templating import Jinja2Templates
def app_context(request: Request) -> typing.Dict[str, typing.Any]:
    return {'app': request.app}

templates = Jinja2Templates(
    directory='templates', context_processors=[app_context]
)
```
!!! info
    Asynchronous functions as context processors are not supported.
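To make the merging step concrete, here is a minimal standalone sketch of how a list of context processors could be folded into a template context. The `build_context` helper and `FakeRequest` class are hypothetical names introduced for illustration, and the merge order shown (processors applied after the explicit context) is an assumption rather than a statement about Starlette's internals.

```python
from typing import Any, Callable

# A context processor takes the request and returns extra context entries
ContextProcessor = Callable[[Any], dict[str, Any]]


def build_context(
    request: Any,
    context: dict[str, Any],
    processors: list[ContextProcessor],
) -> dict[str, Any]:
    """Hypothetical helper: merge processor output into a copy of the context."""
    merged = dict(context)
    merged.setdefault("request", request)
    for processor in processors:
        # Each processor is a plain synchronous call that returns a dict
        merged.update(processor(request))
    return merged


class FakeRequest:
    """Stand-in for a real request object, only for this sketch."""
    app = "my-app"


def app_context(request: Any) -> dict[str, Any]:
    return {"app": request.app}


context = build_context(FakeRequest(), {"title": "Home"}, [app_context])
print(sorted(context))
```

Since each processor is called synchronously in the loop, an `async def` processor would return a coroutine rather than a dictionary, which is consistent with the restriction noted above.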
## Testing template responses
When using the test client, template responses include `.template` and `.context`
attributes.
```python
from starlette.testclient import TestClient
def test_homepage():
    client = TestClient(app)
    response = client.get("/")
    assert response.status_code == 200
    assert response.template.name == 'index.html'
    assert "request" in response.context
```
## Asynchronous template rendering
Jinja2 supports async template rendering; however, as a general rule
we'd recommend that you keep your templates free from logic that invokes
database lookups, or other I/O operations.
Instead we'd recommend that you ensure that your endpoints perform all I/O,
for example, strictly evaluate any database queries within the view and
include the final results in the context.
[jinja2]: https://jinja.palletsprojects.com/en/3.0.x/api/?highlight=environment#writing-filters
[pathlike]: https://docs.python.org/3/library/os.html#os.PathLike
================================================
FILE: docs/testclient.md
================================================
??? abstract "API Reference"
    ::: starlette.testclient.TestClient
        options:
            parameter_headings: false
            show_bases: false
            show_root_heading: true
            heading_level: 3
            filters:
                - "__init__"
The test client allows you to make requests against your ASGI application,
using the `httpx` library.
```python
from starlette.responses import HTMLResponse
from starlette.testclient import TestClient
async def app(scope, receive, send):
    assert scope['type'] == 'http'
    response = HTMLResponse('Hello, world!')
    await response(scope, receive, send)


def test_app():
    client = TestClient(app)
    response = client.get('/')
    assert response.status_code == 200
```
The test client exposes the same interface as any other `httpx` session.
In particular, note that the calls to make a request are just standard
function calls, not awaitables.
You can use any of the standard `httpx` APIs, such as authentication, session
cookie handling, or file uploads.
For example, to set headers on the TestClient you can do:
```python
client = TestClient(app)
# Set headers on the client for future requests
client.headers = {"Authorization": "..."}
response = client.get("/")
# Set headers for each request separately
response = client.get("/", headers={"Authorization": "..."})
```
Similarly, to send files with the TestClient:
```python
client = TestClient(app)
# Send a single file
with open("example.txt", "rb") as f:
    response = client.post("/form", files={"file": f})

# Send multiple files
with open("example.txt", "rb") as f1:
    with open("example.png", "rb") as f2:
        files = {"file1": f1, "file2": ("filename", f2, "image/png")}
        response = client.post("/form", files=files)
```
For more information you can check the `httpx` [documentation](https://www.python-httpx.org/advanced/).
By default the `TestClient` will raise any exceptions that occur in the
application. Occasionally you might want to test the content of 500 error
responses, rather than allowing the client to raise the server exception. In this
case you should use `client = TestClient(app, raise_server_exceptions=False)`.
!!! note
    If you want the `TestClient` to run the `lifespan` handler,
    you will need to use the `TestClient` as a context manager. It will
    not be triggered when the `TestClient` is instantiated. You can learn more about it
    [here](lifespan.md#running-lifespan-in-tests).
### Change client address
By default, the TestClient will set the client host to `"testserver"` and the port to `50000`.
You can change the client address by passing the `client` argument when creating the `TestClient` instance:
```python
client = TestClient(app, client=('localhost', 8000))
```
### Selecting the Async backend
`TestClient` takes arguments `backend` (a string) and `backend_options` (a dictionary).
These options are passed to `anyio.start_blocking_portal()`.
See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options)
for more information about the accepted backend options.
By default, `asyncio` is used with default options.
To run `Trio`, pass `backend="trio"`. For example:
```python
def test_app():
    with TestClient(app, backend="trio") as client:
        ...
```
To run `asyncio` with `uvloop`, pass `backend_options={"use_uvloop": True}`. For example:
```python
def test_app():
    with TestClient(app, backend_options={"use_uvloop": True}) as client:
        ...
```
### Testing WebSocket sessions
You can also test websocket sessions with the test client.
The `httpx` library will be used to build the initial handshake, meaning you
can use the same authentication options and other headers between both http and
websocket testing.
```python
from starlette.testclient import TestClient
from starlette.websockets import WebSocket
async def app(scope, receive, send):
    assert scope['type'] == 'websocket'
    websocket = WebSocket(scope, receive=receive, send=send)
    await websocket.accept()
    await websocket.send_text('Hello, world!')
    await websocket.close()


def test_app():
    client = TestClient(app)
    with client.websocket_connect('/') as websocket:
        data = websocket.receive_text()
        assert data == 'Hello, world!'
```
The operations on the session are standard function calls, not awaitables.
It's important to use the session within a context-managed `with` block. This
ensures that the background thread on which the ASGI application runs is properly
terminated, and that any exceptions that occur within the application are
always raised by the test client.
#### Establishing a test session
* `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `httpx.get()`.
May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection.
`websocket_connect()` must be used as a context manager (in a `with` block).
!!! note
    The `params` argument is not supported by `websocket_connect`. If you need to pass query arguments, hard-code them
    directly in the URL.

    ```python
    with client.websocket_connect('/path?foo=bar') as websocket:
        ...
    ```
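When the query arguments are not fixed strings, one option is to build the URL with the standard library before connecting. This is a small sketch rather than a TestClient feature; the parameter names and values are made up for illustration.

```python
from urllib.parse import urlencode

# Hypothetical query arguments for the websocket endpoint
params = {"foo": "bar", "q": "hello world"}

# urlencode handles escaping, so values with spaces or symbols stay valid
url = f"/path?{urlencode(params)}"
print(url)
```

The resulting string can then be passed as the first argument to `client.websocket_connect(url)`.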
#### Sending data
* `.send_text(data)` - Send the given text to the application.
* `.send_bytes(data)` - Send the given bytes to the application.
* `.send_json(data, mode="text")` - Send the given data to the application. Use `mode="binary"` to send JSON over binary data frames.
#### Receiving data
* `.receive_text()` - Wait for incoming text sent by the application and return it.
* `.receive_bytes()` - Wait for incoming bytestring sent by the application and return it.
* `.receive_json(mode="text")` - Wait for incoming json data sent by the application and return it. Use `mode="binary"` to receive JSON over binary data frames.
May raise `starlette.websockets.WebSocketDisconnect`.
#### Closing the connection
* `.close(code=1000)` - Perform a client-side close of the websocket connection.
### Asynchronous tests
Sometimes you will want to do async things outside of your application.
For example, you might want to check the state of your database after calling your app
using your existing async database client/infrastructure.
For these situations, using `TestClient` is difficult because it creates its own event loop and async
resources (like a database connection) often cannot be shared across event loops.
The simplest way to work around this is to just make your entire test async and use an async client, like [httpx.AsyncClient].
Here is an example of such a test:
```python
from httpx import AsyncClient, ASGITransport
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.requests import Request
from starlette.responses import PlainTextResponse
def hello(request: Request) -> PlainTextResponse:
    return PlainTextResponse("Hello World!")


app = Starlette(routes=[Route("/", hello)])


# if you're using pytest, you'll need to add an async marker like:
# @pytest.mark.anyio  # using https://github.com/agronholm/anyio
# or install and configure pytest-asyncio (https://github.com/pytest-dev/pytest-asyncio)
async def test_app() -> None:
    # note: you _must_ set `base_url` for relative urls like "/" to work
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://testserver") as client:
        r = await client.get("/")
        assert r.status_code == 200
        assert r.text == "Hello World!"
```
[httpx.AsyncClient]: https://www.python-httpx.org/advanced/#calling-into-python-web-apps
================================================
FILE: docs/third-party-packages.md
================================================
Starlette has a rapidly growing community of developers, building tools that integrate into Starlette, tools that depend on Starlette, etc.
Here are some of those third party packages:
## Plugins
### Apitally
GitHub |
Documentation
Analytics, request logging and monitoring for REST APIs built with Starlette (and other frameworks).
### Authlib
GitHub |
Documentation
The ultimate Python library for building OAuth and OpenID Connect clients and servers. Check out how to integrate with [Starlette](https://docs.authlib.org/en/latest/client/starlette.html).
### ChannelBox
GitHub Repository
ChannelBox is a lightweight solution for WebSocket broadcasting in ASGI applications.
It allows sending messages to named WebSocket channel groups and integrates with Starlette and FastAPI.
### Imia
GitHub
An authentication framework for Starlette with pluggable authenticators and login/logout flow.
### Mangum
GitHub
Serverless ASGI adapter for AWS Lambda & API Gateway.
### Nejma
GitHub
Manage and send messages to groups of channels using websockets.
Check out nejma-chat, a simple chat application built using `nejma` and `starlette`.
### Scout APM
GitHub
An APM (Application Performance Monitoring) solution that can
instrument your application to find performance bottlenecks.
### SpecTree
GitHub
Generate an OpenAPI spec document and validate requests and responses with Python annotations. Less boilerplate code (no need for YAML).
### Starlette APISpec
GitHub
Simple APISpec integration for Starlette.
Document your REST API built with Starlette by declaring OpenAPI (Swagger)
schemas in YAML format in your endpoint's docstrings.
### Starlette Compress
GitHub
Starlette-Compress is a fast and simple middleware for compressing responses in Starlette.
It adds ZStd, Brotli, and GZip compression support with sensible default configuration.
### Starlette Context
GitHub
Middleware for Starlette that allows you to store and access the context data of a request.
Can be used with logging so logs automatically use request headers such as x-request-id or x-correlation-id.
### Starlette Cramjam
GitHub
A Starlette middleware that enables **brotli**, **gzip** and **deflate** compression with minimal requirements.
### Starlette OAuth2 API
GitLab
A starlette middleware to add authentication and authorization through JWTs.
It relies solely on an auth provider to issue access and/or id tokens to clients.
### Starlette Prometheus
GitHub
A plugin for providing an endpoint that exposes [Prometheus](https://prometheus.io/) metrics based on its [official python client](https://github.com/prometheus/client_python).
### Starlette WTF
GitHub
A simple tool for integrating Starlette and WTForms. It is modeled on the excellent Flask-WTF library.
### Starlette-Login
GitHub |
Documentation
User session management for Starlette.
It handles the common tasks of logging in, logging out, and remembering your users' sessions over extended periods of time.
### Starsessions
GitHub
An alternate session support implementation with customizable storage backends.
### webargs-starlette
GitHub
Declarative request parsing and validation for Starlette, built on top
of [webargs](https://github.com/marshmallow-code/webargs).
Allows you to parse querystring, JSON, form, headers, and cookies using
type annotations.
### DecoRouter
GitHub
FastAPI style routing for Starlette.
Allows you to use decorators to generate routing tables.
### Starception
GitHub
Beautiful exception page for Starlette apps.
### Starlette-Admin
GitHub |
Documentation
Simple and extensible admin interface framework.
Built with [Tabler](https://tabler.io/) and [Datatables](https://datatables.net/), it allows you
to quickly generate a fully customizable admin interface for your models. You can export your data to many formats (*CSV*, *PDF*,
*Excel*, etc.), filter your data with complex queries including `AND` and `OR` conditions, upload files, and more.
### Vellox
GitHub
Serverless ASGI adapter for GCP Cloud Functions.
### Starlette Bridge
GitHub |
Documentation
With the deprecation of `on_startup` and `on_shutdown`, Starlette Bridge lets you keep
declaring events the old way while it creates the `lifespan` for you internally.
This preserves backwards compatibility for existing packages
while keeping the integrity of the new `lifespan` events of `Starlette`.
## Frameworks
### FastAPI
GitHub |
Documentation
High performance, easy to learn, fast to code, ready for production web API framework.
Inspired by **APIStar**'s previous server system with type declarations for route parameters, based on the OpenAPI specification version 3.0.0+ (with JSON Schema), powered by **Pydantic** for the data handling.
### Flama
GitHub |
Documentation
Flama is a **data-science oriented framework** to rapidly build modern and robust **machine learning** (ML) APIs. The main aim of the framework is to make the deployment of ML APIs ridiculously simple. With Flama, data scientists can quickly turn their ML models into asynchronous, auto-documented APIs with just a single line of code. All in just a few seconds!
Flama comes with an intuitive CLI, and provides an easy-to-learn philosophy to speed up the building of **highly performant** GraphQL, REST, and ML APIs. Besides, it comprises an ideal solution for the development of asynchronous and **production-ready** services, offering **automatic deployment** for ML models.
### Greppo
GitHub |
Documentation
A Python framework for building geospatial dashboards and web-applications.
Greppo is an open-source Python framework that makes it easy to build geospatial dashboards and web-applications. It provides a toolkit to quickly integrate data, algorithms, visualizations and UI for interactivity. It provides APIs to update the variables in the backend, recompute the logic, and reflect the changes in the frontend (data mutation hook).
### Responder
GitHub |
Documentation
Async web service framework. Some Features: flask-style route expression,
yaml support, OpenAPI schema generation, background tasks, graphql.
### Starlette-apps
Roll your own framework with a simple app system, like [Django-GDAPS](https://gdaps.readthedocs.io/en/latest/) or [CakePHP](https://cakephp.org/).
GitHub
### Dark Star
A simple framework to help minimise the code needed to get HTML to the browser. Changes your file paths into Starlette routes and puts your view code right next to your template. Includes support for [htmx](https://htmx.org) to help enhance your frontend.
Docs | GitHub
### Xpresso
A flexible and extendable web framework built on top of Starlette, Pydantic and [di](https://github.com/adriangb/di).
GitHub |
Documentation
### Ellar
GitHub |
Documentation
Ellar is an ASGI web framework for building fast, efficient and scalable REST APIs and server-side applications. It offers a high level of abstraction in building server-side applications and combines elements of OOP (Object Oriented Programming) and FP (Functional Programming), inspired by NestJS.
It is built on 3 core libraries **Starlette**, **Pydantic**, and **injector**.
### Apiman
An extension to easily integrate Swagger/OpenAPI documentation into Starlette projects, providing [SwaggerUI](http://swagger.io/swagger-ui/) and [RedocUI](https://rebilly.github.io/ReDoc/).
GitHub
### Starlette-Babel
Provides translations, localization, and timezone support via Babel integration.
GitHub
### Starlette-StaticResources
GitHub
Allows mounting [package resources](https://docs.python.org/3/library/importlib.resources.html#module-importlib.resources) for static data, similar to [StaticFiles](staticfiles.md).
### Sentry
GitHub |
Documentation
Sentry is a software error detection tool. It offers actionable insights for resolving performance issues and errors, allowing users to diagnose, fix, and optimize Python debugging. Additionally, it integrates seamlessly with Starlette for Python application development. Sentry's capabilities include error tracking, performance insights, contextual information, and alerts/notifications.
### Shiny
GitHub |
Documentation
Leveraging Starlette and asyncio, Shiny allows developers to create effortless Python web applications using the power of reactive programming. Shiny eliminates the hassle of manual state management, automatically determining the best execution path for your app at runtime while simultaneously minimizing re-rendering. This means that Shiny can support everything from the simplest dashboard to full-featured web apps.
================================================
FILE: docs/threadpool.md
================================================
# Thread Pool
Starlette uses a thread pool in several scenarios to avoid blocking the event loop:
- When you create a synchronous endpoint using `def` instead of `async def`
- When serving files with [`FileResponse`](responses.md#fileresponse)
- When handling file uploads with [`UploadFile`](requests.md#request-files)
- When running synchronous background tasks with [`BackgroundTask`](background.md)
- And some other scenarios that may not be documented...
In these cases, Starlette runs the synchronous code in a thread pool rather than on the event loop.
This applies to the endpoint functions and background tasks you create, and also to internal Starlette code.
To be more precise, Starlette uses `anyio.to_thread.run_sync` to run the synchronous code.
## Concurrency Limitations
The default thread pool size is only 40 _tokens_. This means that only 40 threads can run at the same time.
This limit is shared with other libraries: for example FastAPI also uses `anyio` to run sync dependencies, which also uses up thread capacity.
If you need to run more threads, you can increase the number of _tokens_:
```py
import anyio.to_thread
limiter = anyio.to_thread.current_default_thread_limiter()
limiter.total_tokens = 100
```
The above code will increase the number of _tokens_ to 100.
Increasing the number of threads may have a performance and memory impact, so be careful when doing so.
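The effect of a token limit can be seen in a standard-library analogue. This sketch does not use `anyio`: it swaps in `asyncio.Semaphore` and `asyncio.to_thread` as stand-ins for the limiter and `anyio.to_thread.run_sync`, and uses a deliberately tiny limit of 2 so the cap is easy to observe. All names in it are illustrative.

```python
import asyncio
import threading
import time

TOKENS = 2  # stand-in for the limiter's total_tokens (40 by default in anyio)

running = 0  # threads currently inside blocking_work
peak = 0     # highest value of `running` observed
lock = threading.Lock()


def blocking_work() -> None:
    """Simulate blocking I/O while tracking how many threads run at once."""
    global running, peak
    with lock:
        running += 1
        peak = max(peak, running)
    time.sleep(0.05)
    with lock:
        running -= 1


async def run_limited(limiter: asyncio.Semaphore) -> None:
    # A task must hold a token before its work is handed to a thread
    async with limiter:
        await asyncio.to_thread(blocking_work)


async def main() -> int:
    limiter = asyncio.Semaphore(TOKENS)
    # Six tasks compete for two tokens; the rest wait their turn
    await asyncio.gather(*(run_limited(limiter) for _ in range(6)))
    return peak


peak_threads = asyncio.run(main())
print(peak_threads)
```

Tasks that cannot get a token simply wait, which is why a pool saturated with slow synchronous calls delays every other synchronous endpoint or background task sharing the same limiter.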
================================================
FILE: docs/websockets.md
================================================
Starlette includes a `WebSocket` class that fulfils a similar role
to the HTTP request, but that allows sending and receiving data on a websocket.
### WebSocket
Signature: `WebSocket(scope, receive=None, send=None)`
```python
from starlette.websockets import WebSocket
async def app(scope, receive, send):
    websocket = WebSocket(scope=scope, receive=receive, send=send)
    await websocket.accept()
    await websocket.send_text('Hello, world!')
    await websocket.close()
```
WebSockets present a mapping interface, so you can use them in the same
way as a `scope`.
For instance: `websocket['path']` will return the ASGI path.
#### URL
The websocket URL is accessed as `websocket.url`.
The property is actually a subclass of `str`, and also exposes all the
components that can be parsed out of the URL.
For example: `websocket.url.path`, `websocket.url.port`, `websocket.url.scheme`.
#### Headers
Headers are exposed as an immutable, case-insensitive, multi-dict.
For example: `websocket.headers['sec-websocket-version']`
#### Query Parameters
Query parameters are exposed as an immutable multi-dict.
For example: `websocket.query_params['search']`
#### Path Parameters
Router path parameters are exposed as a dictionary interface.
For example: `websocket.path_params['username']`
### Accepting the connection
* `await websocket.accept(subprotocol=None, headers=None)`
### Sending data
* `await websocket.send_text(data)`
* `await websocket.send_bytes(data)`
* `await websocket.send_json(data)`
JSON messages default to being sent over text data frames, from version 0.10.0 onwards.
Use `websocket.send_json(data, mode="binary")` to send JSON over binary data frames.
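The difference between the two modes can be sketched at the level of the ASGI message. The helpers below are hypothetical and are not Starlette's implementation; they assume only that an ASGI `websocket.send` message carries either a `text` or a `bytes` key, and that binary mode means UTF-8 encoded JSON.

```python
import json


def encode_json_message(data, mode="text"):
    """Hypothetical helper: build an ASGI websocket.send message for JSON data."""
    if mode not in {"text", "binary"}:
        raise ValueError('mode must be "text" or "binary"')
    text = json.dumps(data, separators=(",", ":"))
    if mode == "text":
        # Text data frame: the JSON document is sent as a string
        return {"type": "websocket.send", "text": text}
    # Binary data frame: the same JSON document, UTF-8 encoded
    return {"type": "websocket.send", "bytes": text.encode("utf-8")}


def decode_json_message(message, mode="text"):
    """Hypothetical helper: the receiving side must use the matching mode."""
    raw = message["text"] if mode == "text" else message["bytes"].decode("utf-8")
    return json.loads(raw)


text_message = encode_json_message({"hello": "world"})
binary_message = encode_json_message({"hello": "world"}, mode="binary")
print(text_message)
print(binary_message)
```

Both sides need to agree on the mode: a peer expecting text frames will not find the payload in a message that only carries `bytes`.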
### Receiving data
* `await websocket.receive_text()`
* `await websocket.receive_bytes()`
* `await websocket.receive_json()`
May raise `starlette.websockets.WebSocketDisconnect()`.
JSON messages default to being received over text data frames, from version 0.10.0 onwards.
Use `websocket.receive_json(mode="binary")` to receive JSON over binary data frames.
### Iterating data
* `websocket.iter_text()`
* `websocket.iter_bytes()`
* `websocket.iter_json()`
Similar to `receive_text`, `receive_bytes`, and `receive_json` but returns an
async iterator.
```python hl_lines="7-8"
from starlette.websockets import WebSocket


async def app(scope, receive, send):
    websocket = WebSocket(scope=scope, receive=receive, send=send)
    await websocket.accept()
    async for message in websocket.iter_text():
        await websocket.send_text(f"Message text was: {message}")
    await websocket.close()
```
When `starlette.websockets.WebSocketDisconnect` is raised, the iterator will exit.
### Closing the connection
* `await websocket.close(code=1000, reason=None)`
### Sending and receiving messages
If you need to send or receive raw ASGI messages then you should use
`websocket.send()` and `websocket.receive()` rather than using the raw `send` and
`receive` callables. This will ensure that the websocket's state is kept
correctly updated.
* `await websocket.send(message)`
* `await websocket.receive()`
### Send Denial Response
If you call `websocket.close()` before calling `websocket.accept()` then
the server will automatically send an HTTP 403 error to the client.
If you want to send a different error response, you can use the
`websocket.send_denial_response()` method. This will send the response
and then close the connection.
* `await websocket.send_denial_response(response)`
This requires the ASGI server to support the WebSocket Denial Response
extension. If it is not supported a `RuntimeError` will be raised.
In the context of `Starlette`, you can also use the `HTTPException` to achieve the same result.
```python
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.routing import WebSocketRoute
from starlette.websockets import WebSocket
def is_authorized(subprotocols: list[str]):
    if len(subprotocols) != 2:
        return False
    if subprotocols[0] != "Authorization":
        return False
    # Here we are hard coding the token, in a real application you would validate the token
    # against a database or an external service.
    if subprotocols[1] != "token":
        return False
    return True


async def websocket_endpoint(websocket: WebSocket):
    subprotocols = websocket.scope["subprotocols"]
    if not is_authorized(subprotocols):
        raise HTTPException(status_code=401, detail="Unauthorized")
    await websocket.accept("Authorization")
    await websocket.send_text("Hello, world!")
    await websocket.close()
app = Starlette(debug=True, routes=[WebSocketRoute("/ws", websocket_endpoint)])
```
================================================
FILE: mkdocs.yml
================================================
site_name: Starlette
site_description: The little ASGI library that shines.
site_url: https://starlette.dev
repo_name: Kludex/starlette
repo_url: https://github.com/Kludex/starlette
edit_uri: edit/main/docs/
strict: true
theme:
  name: "material"
  custom_dir: docs/overrides
  palette:
    - scheme: "default"
      media: "(prefers-color-scheme: light)"
      toggle:
        icon: "material/lightbulb"
        name: "Switch to dark mode"
    - scheme: "slate"
      media: "(prefers-color-scheme: dark)"
      primary: "blue"
      toggle:
        icon: "material/lightbulb-outline"
        name: "Switch to light mode"
  icon:
    repo: fontawesome/brands/github
  features:
    - content.code.copy
    - toc.follow
nav:
  - Introduction: "index.md"
  - Features:
      - Applications: "applications.md"
      - Requests: "requests.md"
      - Responses: "responses.md"
      - WebSockets: "websockets.md"
      - Routing: "routing.md"
      - Endpoints: "endpoints.md"
      - Middleware: "middleware.md"
      - Static Files: "staticfiles.md"
      - Templates: "templates.md"
      - Database: "database.md"
      - GraphQL: "graphql.md"
      - Authentication: "authentication.md"
      - API Schemas: "schemas.md"
      - Lifespan: "lifespan.md"
      - Background Tasks: "background.md"
      - Server Push: "server-push.md"
      - Exceptions: "exceptions.md"
      - Configuration: "config.md"
      - Thread Pool: "threadpool.md"
      - Test Client: "testclient.md"
  - Release Notes: "release-notes.md"
  - Community:
      - Third Party Packages: "third-party-packages.md"
      - Contributing: "contributing.md"
extra_css:
  - css/custom.css
extra_javascript:
  - js/custom.js
extra:
  analytics:
    provider: google
    property: G-Z37GTYBR6M
  social:
    - icon: fontawesome/brands/github-alt
      link: https://github.com/Kludex/starlette
    - icon: fontawesome/brands/discord
      link: https://discord.com/invite/RxKUF5JuHs
    - icon: fontawesome/brands/twitter
      link: https://x.com/marcelotryle
    - icon: fontawesome/brands/linkedin
      link: https://www.linkedin.com/in/marcelotryle
    - icon: fontawesome/solid/globe
      link: https://fastapiexpert.com
markdown_extensions:
  - attr_list
  - admonition
  - pymdownx.highlight
  - pymdownx.superfences
  - pymdownx.details
  - pymdownx.tabbed:
      alternate_style: true
  - pymdownx.emoji:
      emoji_index: !!python/name:material.extensions.emoji.twemoji
      emoji_generator: !!python/name:material.extensions.emoji.to_svg
  - pymdownx.tasklist:
      custom_checkbox: true
watch:
  - starlette
plugins:
  - search
  - mkdocstrings:
      handlers:
        python:
          options:
            docstring_section_style: list
            show_root_toc_entry: false
            members_order: source
            separate_signature: true
            filters: ["!^_"]
            docstring_options:
              ignore_init_summary: true
            merge_init_into_class: true
            parameter_headings: true
            show_signature_annotations: true
            show_source: false
            signature_crossrefs: true
          inventories:
            - url: https://docs.python.org/3/objects.inv
FILE: pyproject.toml
================================================
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "starlette"
dynamic = ["version"]
description = "The little ASGI library that shines."
readme = "README.md"
license = "BSD-3-Clause"
license-files = ["LICENSE.md"]
requires-python = ">=3.10"
authors = [
{ name = "Tom Christie", email = "tom@tomchristie.com" }
]
maintainers = [
{ name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" },
]
classifiers = [
"Development Status :: 3 - Alpha",
"Environment :: Web Environment",
"Framework :: AnyIO",
"Intended Audience :: Developers",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Topic :: Internet :: WWW/HTTP",
]
dependencies = [
"anyio>=3.6.2,<5",
"typing_extensions>=4.10.0; python_version < '3.13'",
]
[project.optional-dependencies]
full = [
"itsdangerous",
"jinja2",
"python-multipart>=0.0.18",
"pyyaml",
"httpx>=0.27.0,<0.29.0",
]
[dependency-groups]
dev = [
# We add starlette[full] so `uv sync` considers the extras.
"starlette[full]",
"coverage>=7.8.2",
"importlib-metadata==8.7.1",
"mypy==1.16.1",
"ruff==0.15.4",
"types-PyYAML==6.0.12.20250516",
"pytest==9.0.2",
"trio==0.33.0",
# Check dist
"twine==6.2.0",
]
docs = [
"black==26.3.1",
"mkdocstrings>=1.0.2",
"mkdocstrings-python>=2.0.1",
"zensical>=0.0.19",
]
[tool.uv]
default-groups = ["dev", "docs"]
required-version = ">=0.8.6"
[project.urls]
Homepage = "https://github.com/Kludex/starlette"
Documentation = "https://starlette.dev/"
Changelog = "https://starlette.dev/release-notes/"
Funding = "https://github.com/sponsors/Kludex"
Source = "https://github.com/Kludex/starlette"
[tool.hatch.version]
path = "starlette/__init__.py"
[tool.ruff]
line-length = 120
[tool.ruff.lint]
select = [
"E", # https://docs.astral.sh/ruff/rules/#error-e
"F", # https://docs.astral.sh/ruff/rules/#pyflakes-f
"I", # https://docs.astral.sh/ruff/rules/#isort-i
"FA", # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa
"UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up
"RUF100", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
]
ignore = ["UP031"] # https://docs.astral.sh/ruff/rules/printf-string-formatting/
[tool.ruff.lint.isort]
combine-as-imports = true
[tool.mypy]
strict = true
[[tool.mypy.overrides]]
module = "starlette.testclient.*"
implicit_optional = true
[tool.pytest.ini_options]
addopts = "-rXs --strict-config --strict-markers"
xfail_strict = true
filterwarnings = [
# Turn warnings that aren't filtered into exceptions
"error",
"ignore: run_until_first_complete is deprecated and will be removed in a future version.:DeprecationWarning",
"ignore: starlette.middleware.wsgi is deprecated and will be removed in a future release.*:DeprecationWarning",
"ignore: Async generator 'starlette.requests.Request.stream' was garbage collected before it had been exhausted.*:ResourceWarning",
"ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning",
]
[tool.coverage.run]
branch = true
source_pkgs = ["starlette", "tests"]
[tool.coverage.report]
exclude_also = [
"@overload",
"raise NotImplementedError",
]
================================================
FILE: scripts/README.md
================================================
# Development Scripts
* `scripts/install` - Install dependencies in a virtual environment.
* `scripts/test` - Run the test suite.
* `scripts/lint` - Run the automated code linting/formatting tools.
* `scripts/check` - Run the code linting, checking that it passes.
* `scripts/coverage` - Check that code coverage is complete.
* `scripts/build` - Build source and wheel packages.
Styled after GitHub's ["Scripts to Rule Them All"](https://github.com/github/scripts-to-rule-them-all).
================================================
FILE: scripts/build
================================================
#!/bin/sh -e
set -x
uv build
uv run twine check dist/*
uv run zensical build --clean
================================================
FILE: scripts/check
================================================
#!/bin/sh -e
export SOURCE_FILES="starlette tests"
set -x
./scripts/sync-version
uv run ruff format --check --diff $SOURCE_FILES
uv run mypy $SOURCE_FILES
uv run ruff check $SOURCE_FILES
================================================
FILE: scripts/coverage
================================================
#!/bin/sh -e
set -x
uv run coverage report --show-missing --skip-covered --fail-under=100
================================================
FILE: scripts/docs
================================================
#!/bin/sh -e
set -x
uv run zensical serve
================================================
FILE: scripts/install
================================================
#!/bin/sh -e
set -x
uv sync --frozen
================================================
FILE: scripts/lint
================================================
#!/bin/sh -e
export SOURCE_FILES="starlette tests"
set -x
uv run ruff format $SOURCE_FILES
uv run ruff check --fix $SOURCE_FILES
================================================
FILE: scripts/sync-version
================================================
#!/bin/sh -e
SEMVER_REGEX="([0-9]+)\.([0-9]+)\.([0-9]+)(-([0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*))?(\+[0-9A-Za-z-]+)?"
CHANGELOG_VERSION=$(grep -o -E "$SEMVER_REGEX" docs/release-notes.md | head -1)
VERSION=$(grep -o -E "$SEMVER_REGEX" starlette/__init__.py | head -1)
if [ "$CHANGELOG_VERSION" != "$VERSION" ]; then
    echo "Version in changelog does not match version in starlette/__init__.py!"
    exit 1
fi
================================================
FILE: scripts/test
================================================
#!/bin/sh
set -ex
if [ -z "$GITHUB_ACTIONS" ]; then
    scripts/check
fi
uv run coverage run -m pytest "$@"
if [ -z "$GITHUB_ACTIONS" ]; then
    scripts/coverage
fi
================================================
FILE: starlette/__init__.py
================================================
__version__ = "1.0.0rc1"
================================================
FILE: starlette/_exception_handler.py
================================================
from __future__ import annotations

from typing import Any

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

ExceptionHandlers = dict[Any, ExceptionHandler]
StatusHandlers = dict[int, ExceptionHandler]


def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
    for cls in type(exc).__mro__:
        if cls in exc_handlers:
            return exc_handlers[cls]
    return None


def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
    exception_handlers: ExceptionHandlers
    status_handlers: StatusHandlers
    try:
        exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
    except KeyError:
        exception_handlers, status_handlers = {}, {}

    async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
        response_started = False

        async def sender(message: Message) -> None:
            nonlocal response_started

            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await app(scope, receive, sender)
        except Exception as exc:
            handler = None

            if isinstance(exc, HTTPException):
                handler = status_handlers.get(exc.status_code)

            if handler is None:
                handler = _lookup_exception_handler(exception_handlers, exc)

            if handler is None:
                raise exc

            if response_started:
                raise RuntimeError("Caught handled exception, but response already started.") from exc

            if is_async_callable(handler):
                response = await handler(conn, exc)
            else:
                response = await run_in_threadpool(handler, conn, exc)
            if response is not None:
                await response(scope, receive, sender)

    return wrapped_app
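
The handler lookup in `_lookup_exception_handler` walks the exception's method resolution order, so a handler registered for a base class also applies to its subclasses, and the most specific registered class wins. The following is a minimal standalone sketch of that lookup. It does not import Starlette; the `lookup` function and the string "handlers" are illustrative stand-ins, not part of the library.

```python
def lookup(handlers: dict, exc: Exception):
    # Walk the MRO from the most specific class to the least specific one.
    for cls in type(exc).__mro__:
        if cls in handlers:
            return handlers[cls]
    return None


# Hypothetical registry: strings stand in for real handler callables.
handlers = {LookupError: "lookup-handler", Exception: "generic-handler"}

specific = lookup(handlers, KeyError("missing"))  # KeyError subclasses LookupError
generic = lookup(handlers, ValueError("bad"))     # falls through to Exception
unhandled = lookup({}, RuntimeError("boom"))      # nothing registered
```

In `wrap_app_handling_exceptions`, an `HTTPException` is first matched against the status-code handlers, and this class-based lookup is only used when no status handler is found.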
================================================
FILE: starlette/_utils.py
================================================
from __future__ import annotations

import functools
import sys
from collections.abc import Awaitable, Callable, Generator
from contextlib import AbstractAsyncContextManager, contextmanager
from typing import Any, Generic, Protocol, TypeVar, overload

from starlette.types import Scope

if sys.version_info >= (3, 13):  # pragma: no cover
    from inspect import iscoroutinefunction
    from typing import TypeIs
else:  # pragma: no cover
    from asyncio import iscoroutinefunction

    from typing_extensions import TypeIs

has_exceptiongroups = True
if sys.version_info < (3, 11):  # pragma: no cover
    try:
        from exceptiongroup import BaseExceptionGroup  # type: ignore[unused-ignore,import-not-found]
    except ImportError:
        has_exceptiongroups = False

T = TypeVar("T")
AwaitableCallable = Callable[..., Awaitable[T]]


@overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...


@overload
def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...


def is_async_callable(obj: Any) -> Any:
    while isinstance(obj, functools.partial):
        obj = obj.func

    return iscoroutinefunction(obj) or (callable(obj) and iscoroutinefunction(obj.__call__))


T_co = TypeVar("T_co", covariant=True)


class AwaitableOrContextManager(
    Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]
): ...  # pragma: no branch


class SupportsAsyncClose(Protocol):
    async def close(self) -> None: ...  # pragma: no cover


SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)


class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]):
    __slots__ = ("aw", "entered")

    def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None:
        self.aw = aw

    def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]:
        return self.aw.__await__()

    async def __aenter__(self) -> SupportsAsyncCloseType:
        self.entered = await self.aw
        return self.entered

    async def __aexit__(self, *args: Any) -> None | bool:
        await self.entered.close()
        return None


@contextmanager
def collapse_excgroups() -> Generator[None, None, None]:
    try:
        yield
    except BaseException as exc:
        if has_exceptiongroups:  # pragma: no cover
            while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
                exc = exc.exceptions[0]

        raise exc


def get_route_path(scope: Scope) -> str:
    path: str = scope["path"]
    root_path = scope.get("root_path", "")
    if not root_path:
        return path

    if not path.startswith(root_path):
        return path

    if path == root_path:
        return ""

    if path[len(root_path)] == "/":
        return path[len(root_path) :]

    return path
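
To make the branches of `get_route_path` concrete, here is a standalone sketch that applies the same rules to plain strings instead of an ASGI scope. The `route_path` helper and the example paths are illustrative assumptions, not part of Starlette.

```python
def route_path(path: str, root_path: str = "") -> str:
    # Same branching as get_route_path, but on plain strings.
    if not root_path or not path.startswith(root_path):
        return path
    if path == root_path:
        return ""
    # Only strip the prefix when it ends on a path-segment boundary.
    if path[len(root_path)] == "/":
        return path[len(root_path):]
    return path


mounted = route_path("/api/users", "/api")  # prefix stripped
exact = route_path("/api", "/api")          # path equals the root path
lookalike = route_path("/apiary", "/api")   # shares a prefix, not a segment
no_root = route_path("/users")              # no root path set
```

The boundary check is what keeps `/apiary` from being treated as a child of `/api`.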
================================================
FILE: starlette/applications.py
================================================
from __future__ import annotations

from collections.abc import Awaitable, Callable, Mapping, Sequence
from typing import Any, ParamSpec, TypeVar

from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send

AppType = TypeVar("AppType", bound="Starlette")
P = ParamSpec("P")


class Starlette:
    """Creates a Starlette application."""

    def __init__(
        self: AppType,
        debug: bool = False,
        routes: Sequence[BaseRoute] | None = None,
        middleware: Sequence[Middleware] | None = None,
        exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
        lifespan: Lifespan[AppType] | None = None,
    ) -> None:
        """Initializes the application.

        Parameters:
            debug: Boolean indicating if debug tracebacks should be returned on errors.
            routes: A list of routes to serve incoming HTTP and WebSocket requests.
            middleware: A list of middleware to run for every request. A Starlette
                application will always automatically include two middleware classes.
                `ServerErrorMiddleware` is added as the very outermost middleware, to handle
                any uncaught errors occurring anywhere in the entire stack.
                `ExceptionMiddleware` is added as the very innermost middleware, to deal
                with handled exception cases occurring in the routing or endpoints.
            exception_handlers: A mapping of either integer status codes,
                or exception class types onto callables which handle the exceptions.
                Exception handler callables should be of the form
                `handler(request, exc) -> response` and may be either standard functions, or
                async functions.
            lifespan: A lifespan context function, which can be used to perform
                startup and shutdown tasks. This is a newer style that replaces the
                `on_startup` and `on_shutdown` handlers. Use one or the other, not both.
        """
        self.debug = debug
        self.state = State()
        self.router = Router(routes, lifespan=lifespan)
        self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
        self.user_middleware = [] if middleware is None else list(middleware)
        self.middleware_stack: ASGIApp | None = None

    def build_middleware_stack(self) -> ASGIApp:
        debug = self.debug
        error_handler = None
        exception_handlers: dict[Any, ExceptionHandler] = {}

        for key, value in self.exception_handlers.items():
            if key in (500, Exception):
                error_handler = value
            else:
                exception_handlers[key] = value

        middleware = (
            [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
            + self.user_middleware
            + [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)]
        )

        app = self.router
        for cls, args, kwargs in reversed(middleware):
            app = cls(app, *args, **kwargs)
        return app

    @property
    def routes(self) -> list[BaseRoute]:
        return self.router.routes

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        return self.router.url_path_for(name, **path_params)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        scope["app"] = self
        if self.middleware_stack is None:
            self.middleware_stack = self.build_middleware_stack()
        await self.middleware_stack(scope, receive, send)

    def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
        self.router.mount(path, app=app, name=name)  # pragma: no cover

    def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
        self.router.host(host, app=app, name=name)  # pragma: no cover

    def add_middleware(self, middleware_class: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs) -> None:
        if self.middleware_stack is not None:  # pragma: no cover
            raise RuntimeError("Cannot add middleware after an application has started")
        self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs))

    def add_exception_handler(
        self,
        exc_class_or_status_code: int | type[Exception],
        handler: ExceptionHandler,
    ) -> None:  # pragma: no cover
        self.exception_handlers[exc_class_or_status_code] = handler

    def add_route(
        self,
        path: str,
        route: Callable[[Request], Awaitable[Response] | Response],
        methods: list[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
    ) -> None:  # pragma: no cover
        self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema)
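
`build_middleware_stack` lists middleware outermost first and then wraps the router by iterating the list in reverse. The standalone sketch below imitates that wrapping with a toy `Layer` class to show the resulting call order. `Layer`, `router`, and the `trace` list are hypothetical stand-ins, and the call signature is simplified from the real ASGI `(scope, receive, send)` interface.

```python
class Layer:
    # Stand-in for a middleware class: records its name, then calls inward.
    def __init__(self, app, name: str) -> None:
        self.app = app
        self.name = name

    def __call__(self, trace: list) -> None:
        trace.append(self.name)
        self.app(trace)


def router(trace: list) -> None:
    trace.append("router")


# Outermost first, mirroring the list built in build_middleware_stack.
middleware = [
    (Layer, ("ServerErrorMiddleware",)),
    (Layer, ("UserMiddleware",)),
    (Layer, ("ExceptionMiddleware",)),
]

app = router
for cls, args in reversed(middleware):
    # The last entry wraps the router first, so it ends up innermost.
    app = cls(app, *args)

trace: list = []
app(trace)
```

Because `add_middleware` inserts at index 0 of `user_middleware`, the most recently added user middleware sits outermost among the user-supplied layers.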
================================================
FILE: starlette/authentication.py
================================================
from __future__ import annotations

import functools
import inspect
from collections.abc import Callable, Sequence
from typing import Any, ParamSpec
from urllib.parse import urlencode

from starlette._utils import is_async_callable
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse
from starlette.websockets import WebSocket

_P = ParamSpec("_P")


def has_required_scope(conn: HTTPConnection, scopes: Sequence[str]) -> bool:
    for scope in scopes:
        if scope not in conn.auth.scopes:
            return False
    return True


def requires(
    scopes: str | Sequence[str],
    status_code: int = 403,
    redirect: str | None = None,
) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]:
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def decorator(
        func: Callable[_P, Any],
    ) -> Callable[_P, Any]:
        sig = inspect.signature(func)
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name == "request" or parameter.name == "websocket":
                type_ = parameter.name
                break
        else:
            raise Exception(f'No "request" or "websocket" argument on function "{func}"')

        if type_ == "websocket":
            # Handle websocket functions. (Always async)
            @functools.wraps(func)
            async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
                websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
                assert isinstance(websocket, WebSocket)

                if not has_required_scope(websocket, scopes_list):
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        elif is_async_callable(func):
            # Handle async request/response functions.
            @functools.wraps(func)
            async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
                request = kwargs.get("request", args[idx] if idx < len(args) else None)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    if redirect is not None:
                        orig_request_qparam = urlencode({"next": str(request.url)})
                        next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
                        return RedirectResponse(url=next_url, status_code=303)
                    raise HTTPException(status_code=status_code)
                return await func(*args, **kwargs)

            return async_wrapper

        else:
            # Handle sync request/response functions.
            @functools.wraps(func)
            def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
                request = kwargs.get("request", args[idx] if idx < len(args) else None)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    if redirect is not None:
                        orig_request_qparam = urlencode({"next": str(request.url)})
                        next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
                        return RedirectResponse(url=next_url, status_code=303)
                    raise HTTPException(status_code=status_code)
                return func(*args, **kwargs)

            return sync_wrapper

    return decorator


class AuthenticationError(Exception):
    pass


class AuthenticationBackend:
    async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
        raise NotImplementedError()  # pragma: no cover


class AuthCredentials:
    def __init__(self, scopes: Sequence[str] | None = None):
        self.scopes = [] if scopes is None else list(scopes)


class BaseUser:
    @property
    def is_authenticated(self) -> bool:
        raise NotImplementedError()  # pragma: no cover

    @property
    def display_name(self) -> str:
        raise NotImplementedError()  # pragma: no cover

    @property
    def identity(self) -> str:
        raise NotImplementedError()  # pragma: no cover


class SimpleUser(BaseUser):
    def __init__(self, username: str) -> None:
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        return True

    @property
    def display_name(self) -> str:
        return self.username


class UnauthenticatedUser(BaseUser):
    @property
    def is_authenticated(self) -> bool:
        return False

    @property
    def display_name(self) -> str:
        return ""
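
The scope check used by `requires` passes only when every required scope is present, and a single string is treated as a one-element list. The sketch below reproduces just that rule on plain lists. `has_all_scopes` is a hypothetical helper written for illustration; it does not touch a real connection object.

```python
from __future__ import annotations

from collections.abc import Sequence


def has_all_scopes(granted: Sequence[str], required: str | Sequence[str]) -> bool:
    # Normalise a single scope name to a list, as `requires` does.
    required_list = [required] if isinstance(required, str) else list(required)
    # Every required scope must be granted; extra granted scopes are fine.
    return all(scope in granted for scope in required_list)


single_ok = has_all_scopes(["authenticated"], "authenticated")
missing_admin = has_all_scopes(["authenticated"], ["authenticated", "admin"])
nothing_required = has_all_scopes([], [])
```

When the check fails, the wrappers above either close the WebSocket, redirect with a `next` query parameter, or raise `HTTPException` with the configured status code.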
================================================
FILE: starlette/background.py
================================================
from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any, ParamSpec

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool

P = ParamSpec("P")


class BackgroundTask:
    def __init__(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.is_async = is_async_callable(func)

    async def __call__(self) -> None:
        if self.is_async:
            await self.func(*self.args, **self.kwargs)
        else:
            await run_in_threadpool(self.func, *self.args, **self.kwargs)


class BackgroundTasks(BackgroundTask):
    def __init__(self, tasks: Sequence[BackgroundTask] | None = None):
        self.tasks = list(tasks) if tasks else []

    def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
        task = BackgroundTask(func, *args, **kwargs)
        self.tasks.append(task)

    async def __call__(self) -> None:
        for task in self.tasks:
            await task()
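
`BackgroundTask` awaits coroutine functions directly and sends plain functions to a worker thread, and `BackgroundTasks` runs its tasks sequentially in the order they were added. The following standalone sketch shows the same dispatch using only the standard library. It assumes `asyncio.to_thread` as a stand-in for `run_in_threadpool`, and `run_task`, `write_sync`, and `write_async` are illustrative names.

```python
import asyncio
import inspect


async def run_task(func, *args, **kwargs) -> None:
    if inspect.iscoroutinefunction(func):
        await func(*args, **kwargs)
    else:
        # Assumption: asyncio.to_thread stands in for run_in_threadpool.
        await asyncio.to_thread(func, *args, **kwargs)


log: list[str] = []


def write_sync(message: str) -> None:
    log.append(f"sync:{message}")


async def write_async(message: str) -> None:
    log.append(f"async:{message}")


async def main() -> None:
    # Each task is awaited before the next one starts.
    for func, arg in [(write_sync, "a"), (write_async, "b")]:
        await run_task(func, arg)


asyncio.run(main())
```

Because each task is awaited in turn, a slow task delays the ones queued after it.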
================================================
FILE: starlette/concurrency.py
================================================
from __future__ import annotations

import functools
import warnings
from collections.abc import AsyncIterator, Callable, Coroutine, Iterable, Iterator
from typing import ParamSpec, TypeVar

import anyio.to_thread

P = ParamSpec("P")
T = TypeVar("T")


async def run_until_first_complete(*args: tuple[Callable, dict]) -> None:  # type: ignore[type-arg]
    warnings.warn(
        "run_until_first_complete is deprecated and will be removed in a future version.",
        DeprecationWarning,
    )

    async with anyio.create_task_group() as task_group:

        async def run(func: Callable[[], Coroutine]) -> None:  # type: ignore[type-arg]
            await func()
            task_group.cancel_scope.cancel()

        for func, kwargs in args:
            task_group.start_soon(run, functools.partial(func, **kwargs))


async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
    func = functools.partial(func, *args, **kwargs)
    return await anyio.to_thread.run_sync(func)


class _StopIteration(Exception):
    pass


def _next(iterator: Iterator[T]) -> T:
    # We can't raise `StopIteration` from within the threadpool iterator
    # and catch it outside that context, so we coerce them into a different
    # exception type.
    try:
        return next(iterator)
    except StopIteration:
        raise _StopIteration


async def iterate_in_threadpool(
    iterator: Iterable[T],
) -> AsyncIterator[T]:
    as_iterator = iter(iterator)
    while True:
        try:
            yield await anyio.to_thread.run_sync(_next, as_iterator)
        except _StopIteration:
            break
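
`iterate_in_threadpool` pulls one item at a time from a blocking iterator in a worker thread and signals exhaustion with a private exception instead of `StopIteration`. Here is a standalone sketch of the same pattern built on the standard library. It assumes `asyncio.to_thread` in place of `anyio.to_thread.run_sync`, and the names `_Done`, `_next_item`, and `iterate_in_thread` are illustrative.

```python
import asyncio
from collections.abc import AsyncIterator, Iterable, Iterator


class _Done(Exception):
    """Private marker raised when the wrapped iterator is exhausted."""


def _next_item(iterator: Iterator[int]) -> int:
    # Translate StopIteration into a custom exception before it leaves the thread.
    try:
        return next(iterator)
    except StopIteration:
        raise _Done


async def iterate_in_thread(items: Iterable[int]) -> AsyncIterator[int]:
    iterator = iter(items)
    while True:
        try:
            # Each next() call runs in a worker thread, so the event loop stays free.
            yield await asyncio.to_thread(_next_item, iterator)
        except _Done:
            break


async def main() -> list[int]:
    return [item async for item in iterate_in_thread(range(3))]


result = asyncio.run(main())
```

One thread hop per item keeps the pattern simple, at the cost of some overhead for iterators that yield many small items.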
================================================
FILE: starlette/config.py
================================================
from __future__ import annotations

import os
import warnings
from collections.abc import Callable, Iterator, Mapping, MutableMapping
from pathlib import Path
from typing import Any, TypeVar, overload


class undefined:
    pass


class EnvironError(Exception):
    pass


class Environ(MutableMapping[str, str]):
    def __init__(self, environ: MutableMapping[str, str] = os.environ):
        self._environ = environ
        self._has_been_read: set[str] = set()

    def __getitem__(self, key: str) -> str:
        self._has_been_read.add(key)
        return self._environ.__getitem__(key)

    def __setitem__(self, key: str, value: str) -> None:
        if key in self._has_been_read:
            raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.")
        self._environ.__setitem__(key, value)

    def __delitem__(self, key: str) -> None:
        if key in self._has_been_read:
            raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
        self._environ.__delitem__(key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._environ)

    def __len__(self) -> int:
        return len(self._environ)


environ = Environ()

T = TypeVar("T")


class Config:
    def __init__(
        self,
        env_file: str | Path | None = None,
        environ: Mapping[str, str] = environ,
        env_prefix: str = "",
        encoding: str = "utf-8",
    ) -> None:
        self.environ = environ
        self.env_prefix = env_prefix
        self.file_values: dict[str, str] = {}
        if env_file is not None:
            if not os.path.isfile(env_file):
                warnings.warn(f"Config file '{env_file}' not found.")
            else:
                self.file_values = self._read_file(env_file, encoding)

    @overload
    def __call__(self, key: str, *, default: None) -> str | None: ...

    @overload
    def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ...

    @overload
    def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ...

    @overload
    def __call__(
        self,
        key: str,
        cast: Callable[[Any], T] = ...,
        default: Any = ...,
    ) -> T: ...

    @overload
    def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...

    def __call__(
        self,
        key: str,
        cast: Callable[[Any], Any] | None = None,
        default: Any = undefined,
    ) -> Any:
        return self.get(key, cast, default)

    def get(
        self,
        key: str,
        cast: Callable[[Any], Any] | None = None,
        default: Any = undefined,
    ) -> Any:
        key = self.env_prefix + key
        if key in self.environ:
            value = self.environ[key]
            return self._perform_cast(key, value, cast)
        if key in self.file_values:
            value = self.file_values[key]
            return self._perform_cast(key, value, cast)
        if default is not undefined:
            return self._perform_cast(key, default, cast)
        raise KeyError(f"Config '{key}' is missing, and has no default.")

    def _read_file(self, file_name: str | Path, encoding: str) -> dict[str, str]:
        file_values: dict[str, str] = {}
        with open(file_name, encoding=encoding) as input_file:
            for line in input_file.readlines():
                line = line.strip()
                if "=" in line and not line.startswith("#"):
                    key, value = line.split("=", 1)
                    key = key.strip()
                    value = value.strip().strip("\"'")
                    file_values[key] = value
        return file_values

    def _perform_cast(
        self,
        key: str,
        value: Any,
        cast: Callable[[Any], Any] | None = None,
    ) -> Any:
        if cast is None or value is None:
            return value
        elif cast is bool and isinstance(value, str):
            mapping = {"true": True, "1": True, "false": False, "0": False}
            value = value.lower()
            if value not in mapping:
                raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.")
            return mapping[value]
        try:
            return cast(value)
        except (TypeError, ValueError):
            raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.")
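
Two details of `Config` are easy to miss: env-file lines are split on the first `=` only and have surrounding quotes stripped, and `bool` casting accepts just `true`, `false`, `1`, and `0` (case-insensitive). The standalone sketch below mirrors those two rules on in-memory strings. `parse_env_lines` and `cast_bool` are illustrative helpers, not Starlette functions.

```python
def parse_env_lines(lines: list[str]) -> dict[str, str]:
    values: dict[str, str] = {}
    for line in lines:
        line = line.strip()
        # Skip comments and lines without an assignment.
        if "=" in line and not line.startswith("#"):
            key, value = line.split("=", 1)  # split on the first "=" only
            values[key.strip()] = value.strip().strip("\"'")
    return values


def cast_bool(value: str) -> bool:
    mapping = {"true": True, "1": True, "false": False, "0": False}
    lowered = value.lower()
    if lowered not in mapping:
        raise ValueError(f"{value!r} is not a valid bool")
    return mapping[lowered]


parsed = parse_env_lines(["# comment", "DEBUG = True", 'SECRET_KEY="abc=123"', ""])
debug = cast_bool(parsed["DEBUG"])
```

In `Config.get`, values from the environment take precedence over values from the env file, and the default is used only when neither source has the key.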
================================================
FILE: starlette/convertors.py
================================================
from __future__ import annotations

import math
import uuid
from typing import Any, ClassVar, Generic, TypeVar

T = TypeVar("T")


class Convertor(Generic[T]):
    regex: ClassVar[str] = ""

    def convert(self, value: str) -> T:
        raise NotImplementedError()  # pragma: no cover

    def to_string(self, value: T) -> str:
        raise NotImplementedError()  # pragma: no cover


class StringConvertor(Convertor[str]):
    regex = "[^/]+"

    def convert(self, value: str) -> str:
        return value

    def to_string(self, value: str) -> str:
        value = str(value)
        assert "/" not in value, "May not contain path separators"
        assert value, "Must not be empty"
        return value


class PathConvertor(Convertor[str]):
    regex = ".*"

    def convert(self, value: str) -> str:
        return str(value)

    def to_string(self, value: str) -> str:
        return str(value)


class IntegerConvertor(Convertor[int]):
    regex = "[0-9]+"

    def convert(self, value: str) -> int:
        return int(value)

    def to_string(self, value: int) -> str:
        value = int(value)
        assert value >= 0, "Negative integers are not supported"
        return str(value)


class FloatConvertor(Convertor[float]):
    regex = r"[0-9]+(\.[0-9]+)?"

    def convert(self, value: str) -> float:
        return float(value)

    def to_string(self, value: float) -> str:
        value = float(value)
        assert value >= 0.0, "Negative floats are not supported"
        assert not math.isnan(value), "NaN values are not supported"
        assert not math.isinf(value), "Infinite values are not supported"
        return ("%0.20f" % value).rstrip("0").rstrip(".")


class UUIDConvertor(Convertor[uuid.UUID]):
    regex = "[0-9a-fA-F]{8}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{12}"

    def convert(self, value: str) -> uuid.UUID:
        return uuid.UUID(value)

    def to_string(self, value: uuid.UUID) -> str:
        return str(value)


CONVERTOR_TYPES: dict[str, Convertor[Any]] = {
    "str": StringConvertor(),
    "path": PathConvertor(),
    "int": IntegerConvertor(),
    "float": FloatConvertor(),
    "uuid": UUIDConvertor(),
}


def register_url_convertor(key: str, convertor: Convertor[Any]) -> None:
    CONVERTOR_TYPES[key] = convertor
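
`FloatConvertor.to_string` formats with a fixed 20 decimal places and then trims trailing zeros and a dangling decimal point, which yields plain decimal text with no exponent for use in a URL path. A standalone sketch of that formatting step follows. `float_to_path_segment` is an illustrative name, and the three asserts of the original are folded into one here.

```python
import math


def float_to_path_segment(value: float) -> str:
    value = float(value)
    # Same restrictions as FloatConvertor.to_string, combined into one check.
    assert value >= 0.0 and not math.isnan(value) and not math.isinf(value)
    # Fixed-point formatting avoids exponent notation; then trim the padding.
    return ("%0.20f" % value).rstrip("0").rstrip(".")


whole = float_to_path_segment(3.0)
half = float_to_path_segment(1.5)
quarter = float_to_path_segment(0.25)
```

Values that are not exactly representable in binary floating point keep their non-zero trailing digits, so the output can be longer than the literal that produced it.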
================================================
FILE: starlette/datastructures.py
================================================
from __future__ import annotations

from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView
from shlex import shlex
from typing import (
    Any,
    BinaryIO,
    NamedTuple,
    TypeVar,
    cast,
)
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit

from starlette.concurrency import run_in_threadpool
from starlette.types import Scope


class Address(NamedTuple):
    host: str
    port: int


_KeyType = TypeVar("_KeyType")
# Mapping keys are invariant but their values are covariant since
# you can only read them
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)


class URL:
    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: Any,
    ) -> None:
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: Any) -> URL:
        if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                if hostname[-1] != "]":
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: Any) -> URL:
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: Any) -> URL:
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | Sequence[str]) -> URL:
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: Any) -> bool:
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        url = str(self)
        if self.password:
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"


class URLPath(str):
    """
    A URL path string that may also hold an associated protocol and/or host.
    Used by the routing to return `url_path_for` matches.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        if isinstance(base_url, str):
            base_url = URL(base_url)
        if self.protocol:
            scheme = {
                "http": {True: "https", False: "http"},
                "websocket": {True: "wss", False: "ws"},
            }[self.protocol][base_url.is_secure]
        else:
            scheme = base_url.scheme

        netloc = self.host or base_url.netloc
        path = base_url.path.rstrip("/") + str(self)
        return URL(scheme=scheme, netloc=netloc, path=path)
class Secret:
"""
Holds a string value that should not be revealed in tracebacks etc.
You should cast the value to `str` at the point it is required.
"""
def __init__(self, value: str):
self._value = value
def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}('**********')"
def __str__(self) -> str:
return self._value
def __bool__(self) -> bool:
return bool(self._value)
class CommaSeparatedStrings(Sequence[str]):
def __init__(self, value: str | Sequence[str]):
if isinstance(value, str):
splitter = shlex(value, posix=True)
splitter.whitespace = ","
splitter.whitespace_split = True
self._items = [item.strip() for item in splitter]
else:
self._items = list(value)
def __len__(self) -> int:
return len(self._items)
def __getitem__(self, index: int | slice) -> Any:
return self._items[index]
def __iter__(self) -> Iterator[str]:
return iter(self._items)
def __repr__(self) -> str:
class_name = self.__class__.__name__
items = [item for item in self]
return f"{class_name}({items!r})"
def __str__(self) -> str:
return ", ".join(repr(item) for item in self)
class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
_dict: dict[_KeyType, _CovariantValueType]
def __init__(
self,
*args: ImmutableMultiDict[_KeyType, _CovariantValueType]
| Mapping[_KeyType, _CovariantValueType]
| Iterable[tuple[_KeyType, _CovariantValueType]],
**kwargs: Any,
) -> None:
assert len(args) < 2, "Too many arguments."
value: Any = args[0] if args else []
if kwargs:
value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
if not value:
_items: list[tuple[Any, Any]] = []
elif hasattr(value, "multi_items"):
value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
_items = list(value.multi_items())
elif hasattr(value, "items"):
value = cast(Mapping[_KeyType, _CovariantValueType], value)
_items = list(value.items())
else:
value = cast("list[tuple[Any, Any]]", value)
_items = list(value)
self._dict = {k: v for k, v in _items}
self._list = _items
def getlist(self, key: Any) -> list[_CovariantValueType]:
return [item_value for item_key, item_value in self._list if item_key == key]
def keys(self) -> KeysView[_KeyType]:
return self._dict.keys()
def values(self) -> ValuesView[_CovariantValueType]:
return self._dict.values()
def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
return self._dict.items()
def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
return list(self._list)
def __getitem__(self, key: _KeyType) -> _CovariantValueType:
return self._dict[key]
def __contains__(self, key: Any) -> bool:
return key in self._dict
def __iter__(self) -> Iterator[_KeyType]:
return iter(self.keys())
def __len__(self) -> int:
return len(self._dict)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False
return sorted(self._list) == sorted(other._list)
def __repr__(self) -> str:
class_name = self.__class__.__name__
items = self.multi_items()
return f"{class_name}({items!r})"
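# Illustration (not part of the original module): a standalone sketch of the
# two views that `ImmutableMultiDict` keeps over the same pairs. The names
# `pairs`, `last_value` and `getlist` are hypothetical stand-ins for
# `_list`, `_dict` and the `getlist()` method.

```python
pairs = [("tag", "a"), ("page", "1"), ("tag", "b")]

# Dict view: later pairs overwrite earlier ones, so the last value wins.
last_value = {key: value for key, value in pairs}


def getlist(items: list[tuple[str, str]], key: str) -> list[str]:
    # Multi-value view: every value stored for the key, in insertion order.
    return [v for k, v in items if k == key]


print(last_value["tag"], getlist(pairs, "tag"))
```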
class MultiDict(ImmutableMultiDict[Any, Any]):
def __setitem__(self, key: Any, value: Any) -> None:
self.setlist(key, [value])
def __delitem__(self, key: Any) -> None:
self._list = [(k, v) for k, v in self._list if k != key]
del self._dict[key]
def pop(self, key: Any, default: Any = None) -> Any:
self._list = [(k, v) for k, v in self._list if k != key]
return self._dict.pop(key, default)
def popitem(self) -> tuple[Any, Any]:
key, value = self._dict.popitem()
self._list = [(k, v) for k, v in self._list if k != key]
return key, value
def poplist(self, key: Any) -> list[Any]:
values = [v for k, v in self._list if k == key]
self.pop(key)
return values
def clear(self) -> None:
self._dict.clear()
self._list.clear()
def setdefault(self, key: Any, default: Any = None) -> Any:
if key not in self:
self._dict[key] = default
self._list.append((key, default))
return self[key]
def setlist(self, key: Any, values: list[Any]) -> None:
if not values:
self.pop(key, None)
else:
existing_items = [(k, v) for (k, v) in self._list if k != key]
self._list = existing_items + [(key, value) for value in values]
self._dict[key] = values[-1]
def append(self, key: Any, value: Any) -> None:
self._list.append((key, value))
self._dict[key] = value
def update(
self,
*args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
**kwargs: Any,
) -> None:
value = MultiDict(*args, **kwargs)
existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
self._list = existing_items + value.multi_items()
self._dict.update(value)
class QueryParams(ImmutableMultiDict[str, str]):
"""
An immutable multidict of query string parameters.
Keys and values are coerced to `str`.
"""
def __init__(
self,
*args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
**kwargs: Any,
) -> None:
assert len(args) < 2, "Too many arguments."
value = args[0] if args else []
if isinstance(value, str):
super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
elif isinstance(value, bytes):
super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
else:
super().__init__(*args, **kwargs) # type: ignore[arg-type]
self._list = [(str(k), str(v)) for k, v in self._list]
self._dict = {str(k): str(v) for k, v in self._dict.items()}
def __str__(self) -> str:
return urlencode(self._list)
def __repr__(self) -> str:
class_name = self.__class__.__name__
query_string = str(self)
return f"{class_name}({query_string!r})"
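# Illustration (not part of the original module): a standalone sketch of the
# parse and re-encode round trip that `QueryParams` relies on, using only
# `urllib.parse`. The sample query string is made up for the example.

```python
from urllib.parse import parse_qsl, urlencode

# keep_blank_values=True preserves "b=" as ("b", "") instead of dropping it.
pairs = parse_qsl("a=1&b=&a=3", keep_blank_values=True)

# Encoding the list of pairs keeps repeated keys and their order.
query = urlencode(pairs)
print(pairs, query)
```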
class UploadFile:
"""
An uploaded file included as part of the request data.
"""
def __init__(
self,
file: BinaryIO,
*,
size: int | None = None,
filename: str | None = None,
headers: Headers | None = None,
) -> None:
self.filename = filename
self.file = file
self.size = size
self.headers = headers or Headers()
# Capture the max size from SpooledTemporaryFile if one is provided; this slightly speeds up later checks.
# Note: 0 means unlimited, mirroring SpooledTemporaryFile's __init__.
self._max_mem_size = getattr(self.file, "_max_size", 0)
@property
def content_type(self) -> str | None:
return self.headers.get("content-type", None)
@property
def _in_memory(self) -> bool:
# check for SpooledTemporaryFile._rolled
rolled_to_disk = getattr(self.file, "_rolled", True)
return not rolled_to_disk
def _will_roll(self, size_to_add: int) -> bool:
# If the file is not held in memory (already rolled to disk, or not a
# SpooledTemporaryFile at all), always treat the write as hitting disk.
if not self._in_memory:
return True
# Check for SpooledTemporaryFile._max_size
future_size = self.file.tell() + size_to_add
return bool(future_size > self._max_mem_size) if self._max_mem_size else False
async def write(self, data: bytes) -> None:
new_data_len = len(data)
if self.size is not None:
self.size += new_data_len
if self._will_roll(new_data_len):
await run_in_threadpool(self.file.write, data)
else:
self.file.write(data)
async def read(self, size: int = -1) -> bytes:
if self._in_memory:
return self.file.read(size)
return await run_in_threadpool(self.file.read, size)
async def seek(self, offset: int) -> None:
if self._in_memory:
self.file.seek(offset)
else:
await run_in_threadpool(self.file.seek, offset)
async def close(self) -> None:
if self._in_memory:
self.file.close()
else:
await run_in_threadpool(self.file.close)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
class FormData(ImmutableMultiDict[str, UploadFile | str]):
"""
An immutable multidict, containing both file uploads and text input.
"""
def __init__(
self,
*args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
**kwargs: str | UploadFile,
) -> None:
super().__init__(*args, **kwargs)
async def close(self) -> None:
for key, value in self.multi_items():
if isinstance(value, UploadFile):
await value.close()
class Headers(Mapping[str, str]):
"""
An immutable, case-insensitive multidict.
"""
def __init__(
self,
headers: Mapping[str, str] | None = None,
raw: list[tuple[bytes, bytes]] | None = None,
scope: MutableMapping[str, Any] | None = None,
) -> None:
self._list: list[tuple[bytes, bytes]] = []
if headers is not None:
assert raw is None, 'Cannot set both "headers" and "raw".'
assert scope is None, 'Cannot set both "headers" and "scope".'
self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
elif raw is not None:
assert scope is None, 'Cannot set both "raw" and "scope".'
self._list = raw
elif scope is not None:
# scope["headers"] isn't necessarily a list
# it might be a tuple or other iterable
self._list = scope["headers"] = list(scope["headers"])
@property
def raw(self) -> list[tuple[bytes, bytes]]:
return list(self._list)
def keys(self) -> list[str]: # type: ignore[override]
return [key.decode("latin-1") for key, value in self._list]
def values(self) -> list[str]: # type: ignore[override]
return [value.decode("latin-1") for key, value in self._list]
def items(self) -> list[tuple[str, str]]: # type: ignore[override]
return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
def getlist(self, key: str) -> list[str]:
get_header_key = key.lower().encode("latin-1")
return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
def mutablecopy(self) -> MutableHeaders:
return MutableHeaders(raw=self._list[:])
def __getitem__(self, key: str) -> str:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return header_value.decode("latin-1")
raise KeyError(key)
def __contains__(self, key: Any) -> bool:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return True
return False
def __iter__(self) -> Iterator[Any]:
return iter(self.keys())
def __len__(self) -> int:
return len(self._list)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Headers):
return False
return sorted(self._list) == sorted(other._list)
def __repr__(self) -> str:
class_name = self.__class__.__name__
as_dict = dict(self.items())
if len(as_dict) == len(self):
return f"{class_name}({as_dict!r})"
return f"{class_name}(raw={self.raw!r})"
class MutableHeaders(Headers):
def __setitem__(self, key: str, value: str) -> None:
"""
Set the header `key` to `value`, removing any duplicate entries.
Retains insertion order.
"""
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")
found_indexes: list[int] = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
found_indexes.append(idx)
for idx in reversed(found_indexes[1:]):
del self._list[idx]
if found_indexes:
idx = found_indexes[0]
self._list[idx] = (set_key, set_value)
else:
self._list.append((set_key, set_value))
def __delitem__(self, key: str) -> None:
"""
Remove the header `key`.
"""
del_key = key.lower().encode("latin-1")
pop_indexes: list[int] = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == del_key:
pop_indexes.append(idx)
for idx in reversed(pop_indexes):
del self._list[idx]
def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
if not isinstance(other, Mapping):
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
self.update(other)
return self
def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
if not isinstance(other, Mapping):
raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
new = self.mutablecopy()
new.update(other)
return new
@property
def raw(self) -> list[tuple[bytes, bytes]]:
return self._list
def setdefault(self, key: str, value: str) -> str:
"""
If the header `key` does not exist, then set it to `value`.
Returns the header value.
"""
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
return item_value.decode("latin-1")
self._list.append((set_key, set_value))
return value
def update(self, other: Mapping[str, str]) -> None:
for key, val in other.items():
self[key] = val
def append(self, key: str, value: str) -> None:
"""
Append a header, preserving any duplicate entries.
"""
append_key = key.lower().encode("latin-1")
append_value = value.encode("latin-1")
self._list.append((append_key, append_value))
def add_vary_header(self, vary: str) -> None:
existing = self.get("vary")
if existing is not None:
vary = ", ".join([existing, vary])
self["vary"] = vary
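# Illustration (not part of the original module): a standalone sketch of the
# "replace the first match, drop the other duplicates" rule that
# `MutableHeaders.__setitem__` applies to the raw header list. The function
# name `set_header` and the sample headers are hypothetical.

```python
def set_header(raw: list[tuple[bytes, bytes]], key: str, value: str) -> None:
    set_key = key.lower().encode("latin-1")
    set_value = value.encode("latin-1")
    found = [idx for idx, (item_key, _) in enumerate(raw) if item_key == set_key]
    # Delete later duplicates from the end so earlier indexes stay valid.
    for idx in reversed(found[1:]):
        del raw[idx]
    if found:
        # Overwrite in place, which keeps the header's original position.
        raw[found[0]] = (set_key, set_value)
    else:
        raw.append((set_key, set_value))


raw = [(b"vary", b"Accept"), (b"x-id", b"1"), (b"vary", b"Cookie")]
set_header(raw, "Vary", "Origin")
print(raw)
```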
class State:
"""
An object that can be used to store arbitrary state.
Used for `request.state` and `app.state`.
"""
_state: dict[str, Any]
def __init__(self, state: dict[str, Any] | None = None):
if state is None:
state = {}
super().__setattr__("_state", state)
def __setattr__(self, key: Any, value: Any) -> None:
self._state[key] = value
def __getattr__(self, key: Any) -> Any:
try:
return self._state[key]
except KeyError:
message = "'{}' object has no attribute '{}'"
raise AttributeError(message.format(self.__class__.__name__, key))
def __delattr__(self, key: Any) -> None:
del self._state[key]
def __getitem__(self, key: str) -> Any:
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
self._state[key] = value
def __delitem__(self, key: str) -> None:
del self._state[key]
def __iter__(self) -> Iterator[str]:
return iter(self._state)
def __len__(self) -> int:
return len(self._state)
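# Illustration (not part of the original module): a standalone sketch of the
# attribute-to-dict pattern used by `State`. The class name `Bag` is
# hypothetical. It shows why `__init__` goes through `super().__setattr__`:
# the overridden `__setattr__` needs the backing dict to exist first.

```python
from typing import Any


class Bag:
    def __init__(self) -> None:
        # Bypass our own __setattr__ to create the backing dict itself.
        super().__setattr__("_data", {})

    def __setattr__(self, key: str, value: Any) -> None:
        self._data[key] = value

    def __getattr__(self, key: str) -> Any:
        # Only called when normal attribute lookup fails.
        try:
            return self._data[key]
        except KeyError:
            raise AttributeError(key) from None


bag = Bag()
bag.user = "ada"
print(bag.user, hasattr(bag, "missing"))
```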
================================================
FILE: starlette/endpoints.py
================================================
from __future__ import annotations
import json
from collections.abc import Callable, Generator
from typing import Any, Literal
from starlette import status
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocket
class HTTPEndpoint:
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] == "http"
self.scope = scope
self.receive = receive
self.send = send
self._allowed_methods = [
method
for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
if getattr(self, method.lower(), None) is not None
]
def __await__(self) -> Generator[Any, None, None]:
return self.dispatch().__await__()
async def dispatch(self) -> None:
request = Request(self.scope, receive=self.receive)
handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
handler: Callable[[Request], Any] = getattr(self, handler_name, self.method_not_allowed)
is_async = is_async_callable(handler)
if is_async:
response = await handler(request)
else:
response = await run_in_threadpool(handler, request)
await response(self.scope, self.receive, self.send)
async def method_not_allowed(self, request: Request) -> Response:
# If we're running inside a starlette application then raise an
# exception, so that the configurable exception handler can deal with
# returning the response. For plain ASGI apps, just return the response.
headers = {"Allow": ", ".join(self._allowed_methods)}
if "app" in self.scope:
raise HTTPException(status_code=405, headers=headers)
return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
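# Illustration (not part of the original module): a standalone, synchronous
# sketch of the handler lookup in `HTTPEndpoint.dispatch` above. The class
# `Endpoint` and its string return values are hypothetical; the real class
# works with `Request` and `Response` objects.

```python
class Endpoint:
    def get(self) -> str:
        return "got"

    def method_not_allowed(self) -> str:
        return "405"

    def dispatch(self, method: str) -> str:
        # HEAD falls back to the GET handler when no `head` method is defined.
        name = "get" if method == "HEAD" and not hasattr(self, "head") else method.lower()
        # Any method without a matching handler resolves to the fallback.
        handler = getattr(self, name, self.method_not_allowed)
        return handler()


endpoint = Endpoint()
print(endpoint.dispatch("HEAD"), endpoint.dispatch("POST"))
```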
class WebSocketEndpoint:
encoding: Literal["text", "bytes", "json"] | None = None
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] == "websocket"
self.scope = scope
self.receive = receive
self.send = send
def __await__(self) -> Generator[Any, None, None]:
return self.dispatch().__await__()
async def dispatch(self) -> None:
websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
await self.on_connect(websocket)
close_code = status.WS_1000_NORMAL_CLOSURE
try:
while True:
message = await websocket.receive()
if message["type"] == "websocket.receive":
data = await self.decode(websocket, message)
await self.on_receive(websocket, data)
elif message["type"] == "websocket.disconnect": # pragma: no branch
close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
break
except Exception as exc:
close_code = status.WS_1011_INTERNAL_ERROR
raise exc
finally:
await self.on_disconnect(websocket, close_code)
async def decode(self, websocket: WebSocket, message: Message) -> Any:
if self.encoding == "text":
if "text" not in message:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
raise RuntimeError("Expected text websocket messages, but got bytes")
return message["text"]
elif self.encoding == "bytes":
if "bytes" not in message:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
raise RuntimeError("Expected bytes websocket messages, but got text")
return message["bytes"]
elif self.encoding == "json":
if message.get("text") is not None:
text = message["text"]
else:
text = message["bytes"].decode("utf-8")
try:
return json.loads(text)
except json.decoder.JSONDecodeError:
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
raise RuntimeError("Malformed JSON data received.")
assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
return message["text"] if message.get("text") else message["bytes"]
async def on_connect(self, websocket: WebSocket) -> None:
"""Override to handle an incoming websocket connection"""
await websocket.accept()
async def on_receive(self, websocket: WebSocket, data: Any) -> None:
"""Override to handle an incoming websocket message"""
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
"""Override to handle a disconnecting websocket"""
================================================
FILE: starlette/exceptions.py
================================================
from __future__ import annotations
import http
from collections.abc import Mapping
class HTTPException(Exception):
    def __init__(self, status_code: int, detail: str | None = None, headers: Mapping[str, str] | None = None) -> None:
        if detail is None:
            detail = http.HTTPStatus(status_code).phrase
        self.status_code = status_code
        self.detail = detail
        self.headers = headers

    def __str__(self) -> str:
        return f"{self.status_code}: {self.detail}"

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"


class WebSocketException(Exception):
    def __init__(self, code: int, reason: str | None = None) -> None:
        self.code = code
        self.reason = reason or ""

    def __str__(self) -> str:
        return f"{self.code}: {self.reason}"

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}(code={self.code!r}, reason={self.reason!r})"
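# Illustration (not part of the original module): a standalone sketch of how
# the default `detail` of `HTTPException` is derived from the standard
# library's `http.HTTPStatus`. The variable names are hypothetical.

```python
import http

# The reason phrase for a known status code becomes the default detail.
default_detail = http.HTTPStatus(404).phrase
print(default_detail)

# A code that HTTPStatus does not define raises ValueError, so with this
# lookup a non-standard status code needs an explicit `detail`.
try:
    http.HTTPStatus(999)
    unknown_code_rejected = False
except ValueError:
    unknown_code_rejected = True
```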
================================================
FILE: starlette/formparsers.py
================================================
from __future__ import annotations
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from enum import Enum
from tempfile import SpooledTemporaryFile
from typing import TYPE_CHECKING
from urllib.parse import unquote_plus
from starlette.datastructures import FormData, Headers, UploadFile
if TYPE_CHECKING:
import python_multipart as multipart
from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
else:
try:
try:
import python_multipart as multipart
from python_multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
import multipart
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
multipart = None
parse_options_header = None
class FormMessage(Enum):
FIELD_START = 1
FIELD_NAME = 2
FIELD_DATA = 3
FIELD_END = 4
END = 5
@dataclass
class MultipartPart:
content_disposition: bytes | None = None
field_name: str = ""
data: bytearray = field(default_factory=bytearray)
file: UploadFile | None = None
item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
try:
return src.decode(codec)
except (UnicodeDecodeError, LookupError):
return src.decode("latin-1")
class MultiPartException(Exception):
def __init__(self, message: str) -> None:
self.message = message
class FormParser:
def __init__(self, headers: Headers, stream: AsyncGenerator[bytes, None]) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.messages: list[tuple[FormMessage, bytes]] = []
def on_field_start(self) -> None:
message = (FormMessage.FIELD_START, b"")
self.messages.append(message)
def on_field_name(self, data: bytes, start: int, end: int) -> None:
message = (FormMessage.FIELD_NAME, data[start:end])
self.messages.append(message)
def on_field_data(self, data: bytes, start: int, end: int) -> None:
message = (FormMessage.FIELD_DATA, data[start:end])
self.messages.append(message)
def on_field_end(self) -> None:
message = (FormMessage.FIELD_END, b"")
self.messages.append(message)
def on_end(self) -> None:
message = (FormMessage.END, b"")
self.messages.append(message)
async def parse(self) -> FormData:
# Callbacks dictionary.
callbacks: QuerystringCallbacks = {
"on_field_start": self.on_field_start,
"on_field_name": self.on_field_name,
"on_field_data": self.on_field_data,
"on_field_end": self.on_field_end,
"on_end": self.on_end,
}
# Create the parser.
parser = multipart.QuerystringParser(callbacks)
field_name = bytearray()
field_value = bytearray()
items: list[tuple[str, str | UploadFile]] = []
# Feed the parser with data from the request.
async for chunk in self.stream:
if chunk:
parser.write(chunk)
else:
parser.finalize()
messages = list(self.messages)
self.messages.clear()
for message_type, message_bytes in messages:
if message_type == FormMessage.FIELD_START:
field_name = bytearray()
field_value = bytearray()
elif message_type == FormMessage.FIELD_NAME:
field_name.extend(message_bytes)
elif message_type == FormMessage.FIELD_DATA:
field_value.extend(message_bytes)
elif message_type == FormMessage.FIELD_END:
name = unquote_plus(field_name.decode("latin-1"))
value = unquote_plus(field_value.decode("latin-1"))
items.append((name, value))
return FormData(items)
class MultiPartParser:
spool_max_size = 1024 * 1024 # 1MB
"""The maximum size of the spooled temporary file used to store file data."""
max_part_size = 1024 * 1024 # 1MB
"""The maximum size of a part in the multipart request."""
def __init__(
self,
headers: Headers,
stream: AsyncGenerator[bytes, None],
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024, # 1MB
) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.max_files = max_files
self.max_fields = max_fields
self.items: list[tuple[str, str | UploadFile]] = []
self._current_files = 0
self._current_fields = 0
self._current_partial_header_name: bytes = b""
self._current_partial_header_value: bytes = b""
self._current_part = MultipartPart()
self._charset = ""
self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
self._file_parts_to_finish: list[MultipartPart] = []
self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
self.max_part_size = max_part_size
def on_part_begin(self) -> None:
self._current_part = MultipartPart()
def on_part_data(self, data: bytes, start: int, end: int) -> None:
message_bytes = data[start:end]
if self._current_part.file is None:
if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
self._current_part.data.extend(message_bytes)
else:
self._file_parts_to_write.append((self._current_part, message_bytes))
def on_part_end(self) -> None:
if self._current_part.file is None:
self.items.append(
(
self._current_part.field_name,
_user_safe_decode(self._current_part.data, self._charset),
)
)
else:
self._file_parts_to_finish.append(self._current_part)
# The file can be added to the items right now even though it's not
# finished yet, because it will be finished in the `parse()` method, before
# self.items is used in the return value.
self.items.append((self._current_part.field_name, self._current_part.file))
def on_header_field(self, data: bytes, start: int, end: int) -> None:
self._current_partial_header_name += data[start:end]
def on_header_value(self, data: bytes, start: int, end: int) -> None:
self._current_partial_header_value += data[start:end]
def on_header_end(self) -> None:
field = self._current_partial_header_name.lower()
if field == b"content-disposition":
self._current_part.content_disposition = self._current_partial_header_value
self._current_part.item_headers.append((field, self._current_partial_header_value))
self._current_partial_header_name = b""
self._current_partial_header_value = b""
def on_headers_finished(self) -> None:
disposition, options = parse_options_header(self._current_part.content_disposition)
try:
self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
except KeyError:
raise MultiPartException('The Content-Disposition header field "name" must be provided.')
if b"filename" in options:
self._current_files += 1
if self._current_files > self.max_files:
raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
filename = _user_safe_decode(options[b"filename"], self._charset)
tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
self._files_to_close_on_error.append(tempfile)
self._current_part.file = UploadFile(
file=tempfile, # type: ignore[arg-type]
size=0,
filename=filename,
headers=Headers(raw=self._current_part.item_headers),
)
else:
self._current_fields += 1
if self._current_fields > self.max_fields:
raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
self._current_part.file = None
def on_end(self) -> None:
pass
async def parse(self) -> FormData:
# Parse the Content-Type header to get the multipart boundary.
_, params = parse_options_header(self.headers["Content-Type"])
charset = params.get(b"charset", "utf-8")
if isinstance(charset, bytes):
charset = charset.decode("latin-1")
self._charset = charset
try:
boundary = params[b"boundary"]
except KeyError:
raise MultiPartException("Missing boundary in multipart.")
# Callbacks dictionary.
callbacks: MultipartCallbacks = {
"on_part_begin": self.on_part_begin,
"on_part_data": self.on_part_data,
"on_part_end": self.on_part_end,
"on_header_field": self.on_header_field,
"on_header_value": self.on_header_value,
"on_header_end": self.on_header_end,
"on_headers_finished": self.on_headers_finished,
"on_end": self.on_end,
}
# Create the parser.
parser = multipart.MultipartParser(boundary, callbacks)
try:
# Feed the parser with data from the request.
async for chunk in self.stream:
parser.write(chunk)
# Write file data here rather than in the callback methods above. The
# callbacks are regular, non-async functions, so calling the file methods
# directly in them would block the event loop in the main thread. The
# UploadFile methods are awaited instead, which lets them run the
# corresponding file methods *in a threadpool*.
for part, data in self._file_parts_to_write:
assert part.file # for type checkers
await part.file.write(data)
for part in self._file_parts_to_finish:
assert part.file # for type checkers
await part.file.seek(0)
self._file_parts_to_write.clear()
self._file_parts_to_finish.clear()
parser.finalize()
except MultiPartException as exc:
# Close all the files if there was an error.
for file in self._files_to_close_on_error:
file.close()
raise exc
return FormData(self.items)
================================================
FILE: starlette/middleware/__init__.py
================================================
from __future__ import annotations
from collections.abc import Awaitable, Callable, Iterator
from typing import Any, ParamSpec, Protocol
P = ParamSpec("P")
_Scope = Any
_Receive = Callable[[], Awaitable[Any]]
_Send = Callable[[Any], Awaitable[None]]
# Since `starlette.types.ASGIApp` type differs from `ASGIApplication` from `asgiref`
# we need to define a more permissive version of ASGIApp that doesn't cause type errors.
_ASGIApp = Callable[[_Scope, _Receive, _Send], Awaitable[None]]
class _MiddlewareFactory(Protocol[P]):
    def __call__(self, app: _ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> _ASGIApp: ...  # pragma: no cover


class Middleware:
    def __init__(self, cls: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs) -> None:
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def __iter__(self) -> Iterator[Any]:
        as_tuple = (self.cls, self.args, self.kwargs)
        return iter(as_tuple)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        args_strings = [f"{value!r}" for value in self.args]
        option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
        name = getattr(self.cls, "__name__", "")
        args_repr = ", ".join([name] + args_strings + option_strings)
        return f"{class_name}({args_repr})"
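# Illustration (not part of the original module): a standalone sketch of why
# `Middleware.__iter__` is convenient. A spec object that iterates as
# (factory, args, kwargs) can be unpacked directly while wrapping an app.
# `Spec`, `add_prefix` and the string-based "app" are hypothetical stand-ins,
# and the reversed wrapping order is an assumption about how a caller might
# build the stack so that the first listed spec ends up outermost.

```python
from typing import Any, Callable, Iterator


class Spec:
    def __init__(self, factory: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        self.factory, self.args, self.kwargs = factory, args, kwargs

    def __iter__(self) -> Iterator[Any]:
        return iter((self.factory, self.args, self.kwargs))


def add_prefix(app: Callable[[str], str], prefix: str) -> Callable[[str], str]:
    # A toy "middleware": prepend a prefix, then call the wrapped app.
    return lambda text: app(prefix + text)


specs = [Spec(add_prefix, prefix="a-"), Spec(add_prefix, prefix="b-")]
app: Callable[[str], str] = str.upper
for factory, args, kwargs in reversed(specs):
    app = factory(app, *args, **kwargs)

print(app("x"))
```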
================================================
FILE: starlette/middleware/authentication.py
================================================
from __future__ import annotations
from collections.abc import Callable
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
AuthenticationError,
UnauthenticatedUser,
)
from starlette.requests import HTTPConnection
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
class AuthenticationMiddleware:
    def __init__(
        self,
        app: ASGIApp,
        backend: AuthenticationBackend,
        on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
    ) -> None:
        self.app = app
        self.backend = backend
        self.on_error: Callable[[HTTPConnection, AuthenticationError], Response] = (
            on_error if on_error is not None else self.default_on_error
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] not in ["http", "websocket"]:
            await self.app(scope, receive, send)
            return
        conn = HTTPConnection(scope)
        try:
            auth_result = await self.backend.authenticate(conn)
        except AuthenticationError as exc:
            response = self.on_error(conn, exc)
            if scope["type"] == "websocket":
                await send({"type": "websocket.close", "code": 1000})
            else:
                await response(scope, receive, send)
            return
        if auth_result is None:
            auth_result = AuthCredentials(), UnauthenticatedUser()
        scope["auth"], scope["user"] = auth_result
        await self.app(scope, receive, send)

    @staticmethod
    def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
        return PlainTextResponse(str(exc), status_code=400)
================================================
FILE: starlette/middleware/base.py
================================================
from __future__ import annotations
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable, Mapping, MutableMapping
from typing import Any, TypeVar
import anyio
from starlette._utils import collapse_excgroups
from starlette.requests import ClientDisconnect, Request
from starlette.responses import Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = Callable[[Request], Awaitable[Response]]
DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]]
BodyStreamGenerator = AsyncGenerator[bytes | MutableMapping[str, Any], None]
AsyncContentStream = AsyncIterable[str | bytes | memoryview | MutableMapping[str, Any]]
T = TypeVar("T")
class _CachedRequest(Request):
"""
If the user calls Request.body() from their dispatch function, we cache the
entire request body in memory and pass it to downstream middlewares. If they
call Request.stream() instead, we only send an empty body so that downstream
consumers don't hang forever.
"""
def __init__(self, scope: Scope, receive: Receive):
super().__init__(scope, receive)
self._wrapped_rcv_disconnected = False
self._wrapped_rcv_consumed = False
self._wrapped_rc_stream = self.stream()
async def wrapped_receive(self) -> Message:
# wrapped_rcv state 1: disconnected
if self._wrapped_rcv_disconnected:
# we've already sent a disconnect to the downstream app
# we don't need to wait to get another one
# (although most ASGI servers will just keep sending it)
return {"type": "http.disconnect"}
# wrapped_rcv state 2: consumed but not yet disconnected
if self._wrapped_rcv_consumed:
# since the downstream app has consumed us all that is left
# is to send it a disconnect
if self._is_disconnected:
# the middleware has already seen the disconnect
# since we know the client is disconnected no need to wait
# for the message
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
# we don't know yet if the client is disconnected or not
# so we'll wait until we get that message
msg = await self.receive()
if msg["type"] != "http.disconnect": # pragma: no cover
# at this point a disconnect is all that we should be receiving
# if we get something else, things went wrong somewhere
raise RuntimeError(f"Unexpected message received: {msg['type']}")
self._wrapped_rcv_disconnected = True
return msg
# wrapped_rcv state 3: not yet consumed
if getattr(self, "_body", None) is not None:
# body() was called, we return it even if the client disconnected
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": self._body,
"more_body": False,
}
elif self._stream_consumed:
# stream() was called to completion
# return an empty body so that downstream apps don't hang
# waiting for a disconnect
self._wrapped_rcv_consumed = True
return {
"type": "http.request",
"body": b"",
"more_body": False,
}
else:
# body() was never called and stream() wasn't consumed
try:
stream = self.stream()
chunk = await stream.__anext__()
self._wrapped_rcv_consumed = self._stream_consumed
return {
"type": "http.request",
"body": chunk,
"more_body": not self._stream_consumed,
}
except ClientDisconnect:
self._wrapped_rcv_disconnected = True
return {"type": "http.disconnect"}
class BaseHTTPMiddleware:
def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None:
self.app = app
self.dispatch_func = self.dispatch if dispatch is None else dispatch
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = _CachedRequest(scope, receive)
wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()
app_exc: Exception | None = None
exception_already_raised = False
async def call_next(request: Request) -> Response:
async def receive_or_disconnect() -> Message:
if response_sent.is_set():
return {"type": "http.disconnect"}
async with anyio.create_task_group() as task_group:
async def wrap(func: Callable[[], Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result
task_group.start_soon(wrap, response_sent.wait)
message = await wrap(wrapped_receive)
if response_sent.is_set():
return {"type": "http.disconnect"}
return message
async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
except anyio.BrokenResourceError:
# recv_stream has been closed, i.e. response_sent has been set.
return
async def coro() -> None:
nonlocal app_exc
with send_stream:
try:
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc
task_group.start_soon(coro)
try:
message = await recv_stream.receive()
info = message.get("info", None)
if message["type"] == "http.response.debug" and info is not None:
message = await recv_stream.receive()
except anyio.EndOfStream:
if app_exc is not None:
nonlocal exception_already_raised
exception_already_raised = True
# Prevent `anyio.EndOfStream` from polluting app exception context.
# If both cause and context are None then the context is suppressed
# and `anyio.EndOfStream` is not present in the exception traceback.
# If exception cause is not None then it is propagated with
# reraising here.
# If exception has no cause but has context set then the context is
# propagated as a cause with the reraise. This is necessary in order
# to prevent `anyio.EndOfStream` from polluting the exception
# context.
raise app_exc from app_exc.__cause__ or app_exc.__context__
raise RuntimeError("No response returned.")
assert message["type"] == "http.response.start"
async def body_stream() -> BodyStreamGenerator:
async for message in recv_stream:
if message["type"] == "http.response.pathsend":
yield message
break
assert message["type"] == "http.response.body", f"Unexpected message: {message}"
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
break
response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
response.raw_headers = message["headers"]
return response
streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream()
send_stream, recv_stream = streams
with recv_stream, send_stream, collapse_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
recv_stream.close()
if app_exc is not None and not exception_already_raised:
raise app_exc
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
raise NotImplementedError() # pragma: no cover
class _StreamingResponse(Response):
def __init__(
self,
content: AsyncContentStream,
status_code: int = 200,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
info: Mapping[str, Any] | None = None,
) -> None:
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)
self.background = None
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
await send({"type": "http.response.debug", "info": self.info})
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
should_close_body = True
async for chunk in self.body_iterator:
if isinstance(chunk, dict):
# We got an ASGI message which is not response body (eg: pathsend)
should_close_body = False
await send(chunk)
continue
await send({"type": "http.response.body", "body": chunk, "more_body": True})
if should_close_body:
await send({"type": "http.response.body", "body": b"", "more_body": False})
if self.background:
await self.background()
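
As a rough illustration of the message sequence that `_StreamingResponse.__call__` emits for plain byte chunks, the stdlib-only sketch below replays it against a recording `send` callable. The names `stream_response` and `demo` are hypothetical and not part of Starlette; the debug, pathsend, and background-task branches are deliberately left out.

```python
import asyncio
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List

Message = Dict[str, Any]
Send = Callable[[Message], Awaitable[None]]


async def stream_response(chunks: AsyncIterator[bytes], send: Send, status: int = 200) -> None:
    # One start message carrying the status line and headers.
    await send({"type": "http.response.start", "status": status, "headers": []})
    # Every chunk is forwarded with more_body=True ...
    async for chunk in chunks:
        await send({"type": "http.response.body", "body": chunk, "more_body": True})
    # ... and an empty body with more_body=False closes the response.
    await send({"type": "http.response.body", "body": b"", "more_body": False})


async def demo() -> List[Message]:
    sent: List[Message] = []

    async def send(message: Message) -> None:
        sent.append(message)  # record instead of writing to a socket

    async def chunks() -> AsyncIterator[bytes]:
        yield b"hello, "
        yield b"world"

    await stream_response(chunks(), send)
    return sent


messages = asyncio.run(demo())
```

Because the terminating message is always a separate empty body, the wrapper never needs to know in advance which chunk is the last one.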
================================================
FILE: starlette/middleware/cors.py
================================================
from __future__ import annotations
import functools
import re
from collections.abc import Sequence
from starlette.datastructures import Headers, MutableHeaders
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}
class CORSMiddleware:
def __init__(
self,
app: ASGIApp,
allow_origins: Sequence[str] = (),
allow_methods: Sequence[str] = ("GET",),
allow_headers: Sequence[str] = (),
allow_credentials: bool = False,
allow_origin_regex: str | None = None,
allow_private_network: bool = False,
expose_headers: Sequence[str] = (),
max_age: int = 600,
) -> None:
if "*" in allow_methods:
allow_methods = ALL_METHODS
compiled_allow_origin_regex = None
if allow_origin_regex is not None:
compiled_allow_origin_regex = re.compile(allow_origin_regex)
allow_all_origins = "*" in allow_origins
allow_all_headers = "*" in allow_headers
preflight_explicit_allow_origin = not allow_all_origins or allow_credentials
simple_headers: dict[str, str] = {}
if allow_all_origins:
simple_headers["Access-Control-Allow-Origin"] = "*"
if allow_credentials:
simple_headers["Access-Control-Allow-Credentials"] = "true"
if expose_headers:
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
preflight_headers: dict[str, str] = {}
if preflight_explicit_allow_origin:
# The origin value will be set in preflight_response() if it is allowed.
preflight_headers["Vary"] = "Origin"
else:
preflight_headers["Access-Control-Allow-Origin"] = "*"
preflight_headers.update(
{
"Access-Control-Allow-Methods": ", ".join(allow_methods),
"Access-Control-Max-Age": str(max_age),
}
)
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
if allow_headers and not allow_all_headers:
preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
if allow_credentials:
preflight_headers["Access-Control-Allow-Credentials"] = "true"
self.app = app
self.allow_origins = allow_origins
self.allow_methods = allow_methods
self.allow_headers = [h.lower() for h in allow_headers]
self.allow_all_origins = allow_all_origins
self.allow_all_headers = allow_all_headers
self.allow_credentials = allow_credentials
self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
self.allow_origin_regex = compiled_allow_origin_regex
self.allow_private_network = allow_private_network
self.simple_headers = simple_headers
self.preflight_headers = preflight_headers
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http": # pragma: no cover
await self.app(scope, receive, send)
return
method = scope["method"]
headers = Headers(scope=scope)
origin = headers.get("origin")
if origin is None:
await self.app(scope, receive, send)
return
if method == "OPTIONS" and "access-control-request-method" in headers:
response = self.preflight_response(request_headers=headers)
await response(scope, receive, send)
return
await self.simple_response(scope, receive, send, request_headers=headers)
def is_allowed_origin(self, origin: str) -> bool:
if self.allow_all_origins:
return True
if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
return True
return origin in self.allow_origins
def preflight_response(self, request_headers: Headers) -> Response:
requested_origin = request_headers["origin"]
requested_method = request_headers["access-control-request-method"]
requested_headers = request_headers.get("access-control-request-headers")
requested_private_network = request_headers.get("access-control-request-private-network")
headers = dict(self.preflight_headers)
failures: list[str] = []
if self.is_allowed_origin(origin=requested_origin):
if self.preflight_explicit_allow_origin:
# The "else" case is already accounted for in self.preflight_headers
# and the value would be "*".
headers["Access-Control-Allow-Origin"] = requested_origin
else:
failures.append("origin")
if requested_method not in self.allow_methods:
failures.append("method")
# If we allow all headers, then we have to mirror back any requested
# headers in the response.
if self.allow_all_headers and requested_headers is not None:
headers["Access-Control-Allow-Headers"] = requested_headers
elif requested_headers is not None:
for header in [h.lower() for h in requested_headers.split(",")]:
if header.strip() not in self.allow_headers:
failures.append("headers")
break
if requested_private_network is not None:
if self.allow_private_network:
headers["Access-Control-Allow-Private-Network"] = "true"
else:
failures.append("private-network")
# We don't strictly need to use 400 responses here, since it's up to
# the browser to enforce the CORS policy, but it's more informative
# if we do.
if failures:
failure_text = "Disallowed CORS " + ", ".join(failures)
return PlainTextResponse(failure_text, status_code=400, headers=headers)
return PlainTextResponse("OK", status_code=200, headers=headers)
async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
send = functools.partial(self.send, send=send, request_headers=request_headers)
await self.app(scope, receive, send)
async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
if message["type"] != "http.response.start":
await send(message)
return
message.setdefault("headers", [])
headers = MutableHeaders(scope=message)
headers.update(self.simple_headers)
origin = request_headers["Origin"]
# If credentials are allowed, then we must respond with the specific origin instead of '*'.
if self.allow_all_origins and self.allow_credentials:
self.allow_explicit_origin(headers, origin)
# If we only allow specific origins, then we have to mirror back the Origin header in the response.
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
self.allow_explicit_origin(headers, origin)
await send(message)
@staticmethod
def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
headers["Access-Control-Allow-Origin"] = origin
headers.add_vary_header("Origin")
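
The origin and header checks in `CORSMiddleware` can be approximated with the standard library alone. This is a minimal sketch that mirrors the logic of `is_allowed_origin()` and the header loop in `preflight_response()`; the free functions `is_allowed_origin` and `headers_allowed` are hypothetical helpers written for this example, not Starlette APIs.

```python
import re
from typing import Optional, Sequence

SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}


def is_allowed_origin(
    origin: str,
    allow_origins: Sequence[str],
    allow_origin_regex: Optional[str] = None,
) -> bool:
    # A literal "*" entry allows every origin.
    if "*" in allow_origins:
        return True
    # fullmatch() means the pattern must cover the whole origin string,
    # so a trailing ".evil.com" cannot slip past a prefix match.
    if allow_origin_regex is not None and re.fullmatch(allow_origin_regex, origin):
        return True
    return origin in allow_origins


def headers_allowed(requested: str, allow_headers: Sequence[str]) -> bool:
    # Safelisted headers are always accepted; comparison is case-insensitive.
    allowed = {h.lower() for h in SAFELISTED_HEADERS | set(allow_headers)}
    return all(h.strip() in allowed for h in requested.lower().split(","))
```

For instance, with the pattern `https://.*\.example\.org`, the origin `https://app.example.org` is accepted while `https://app.example.org.evil.com` is not.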
================================================
FILE: starlette/middleware/errors.py
================================================
from __future__ import annotations
import html
import inspect
import sys
import traceback
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
STYLES = """
p {
color: #211c1c;
}
.traceback-container {
border: 1px solid #038BB8;
}
.traceback-title {
background-color: #038BB8;
color: lemonchiffon;
padding: 12px;
font-size: 20px;
margin-top: 0px;
}
.frame-line {
padding-left: 10px;
font-family: monospace;
}
.frame-filename {
font-family: monospace;
}
.center-line {
background-color: #038BB8;
color: #f9f6e1;
padding: 5px 0px 5px 5px;
}
.lineno {
margin-right: 5px;
}
.frame-title {
font-weight: unset;
padding: 10px 10px 10px 10px;
background-color: #E4F4FD;
margin-right: 10px;
color: #191f21;
font-size: 17px;
border: 1px solid #c7dce8;
}
.collapse-btn {
float: right;
padding: 0px 5px 1px 5px;
border: solid 1px #96aebb;
cursor: pointer;
}
.collapsed {
display: none;
}
.source-code {
font-family: courier;
font-size: small;
padding-bottom: 10px;
}
"""
JS = """
"""
TEMPLATE = """
Starlette Debugger
500 Server Error
{error}
Traceback
{exc_html}
{js}
"""
FRAME_TEMPLATE = """
File {frame_filename},
line {frame_lineno},
in {frame_name}{collapse_button}
{code_context}
""" # noqa: E501
LINE = """
{lineno}. {line}
"""
CENTER_LINE = """
{lineno}. {line}
"""
class ServerErrorMiddleware:
"""
Handles returning 500 responses when a server error occurs.
If 'debug' is set, then traceback responses will be returned,
otherwise the designated 'handler' will be called.
This middleware class should generally be used to wrap *everything*
else up, so that unhandled exceptions anywhere in the stack
always result in an appropriate 500 response.
"""
def __init__(
self,
app: ASGIApp,
handler: ExceptionHandler | None = None,
debug: bool = False,
) -> None:
self.app = app
self.handler = handler
self.debug = debug
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
response_started = False
async def _send(message: Message) -> None:
nonlocal response_started, send
if message["type"] == "http.response.start":
response_started = True
await send(message)
try:
await self.app(scope, receive, _send)
except Exception as exc:
request = Request(scope)
if self.debug:
# In debug mode, return traceback responses.
response = self.debug_response(request, exc)
elif self.handler is None:
# Use our default 500 error handler.
response = self.error_response(request, exc)
else:
# Use an installed 500 error handler.
if is_async_callable(self.handler):
response = await self.handler(request, exc)
else:
response = await run_in_threadpool(self.handler, request, exc)
if not response_started:
await response(scope, receive, send)
# We always continue to raise the exception.
# This allows servers to log the error, or allows test clients
# to optionally raise the error within the test case.
raise exc
def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
values = {
# HTML escape - line could contain < or >
"line": html.escape(line).replace(" ", "&nbsp;"),
"lineno": (frame_lineno - frame_index) + index,
}
if index != frame_index:
return LINE.format(**values)
return CENTER_LINE.format(**values)
def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str:
code_context = "".join(
self.format_line(
index,
line,
frame.lineno,
frame.index, # type: ignore[arg-type]
)
for index, line in enumerate(frame.code_context or [])
)
values = {
# HTML escape - filename could contain < or >, especially if it's a virtual
# file e.g. in the REPL
"frame_filename": html.escape(frame.filename),
"frame_lineno": frame.lineno,
# HTML escape - if you try very hard it's possible to name a function with <
# or >
"frame_name": html.escape(frame.function),
"code_context": code_context,
"collapsed": "collapsed" if is_collapsed else "",
"collapse_button": "+" if is_collapsed else "‒",
}
return FRAME_TEMPLATE.format(**values)
def generate_html(self, exc: Exception, limit: int = 7) -> str:
traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
exc_html = ""
is_collapsed = False
exc_traceback = exc.__traceback__
if exc_traceback is not None:
frames = inspect.getinnerframes(exc_traceback, limit)
for frame in reversed(frames):
exc_html += self.generate_frame_html(frame, is_collapsed)
is_collapsed = True
if sys.version_info >= (3, 13): # pragma: no cover
exc_type_str = traceback_obj.exc_type_str
else: # pragma: no cover
exc_type_str = traceback_obj.exc_type.__name__
# escape error class and text
error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}"
return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html)
def generate_plain_text(self, exc: Exception) -> str:
return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
def debug_response(self, request: Request, exc: Exception) -> Response:
accept = request.headers.get("accept", "")
if "text/html" in accept:
content = self.generate_html(exc)
return HTMLResponse(content, status_code=500)
content = self.generate_plain_text(exc)
return PlainTextResponse(content, status_code=500)
def error_response(self, request: Request, exc: Exception) -> Response:
return PlainTextResponse("Internal Server Error", status_code=500)
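
The line-number arithmetic in `format_line()` can be reproduced outside the middleware. In the sketch below, `failing` is a hypothetical function used only to produce a traceback; the `inspect.getinnerframes` call matches the one in `generate_html()`. When source code is unavailable (for example in some interactive sessions), `code_context` is `None` and the list simply stays empty.

```python
import inspect


def failing() -> None:
    raise ValueError("boom")


try:
    failing()
except ValueError as exc:
    # Same call as generate_html(): up to 7 lines of source context per frame.
    frames = inspect.getinnerframes(exc.__traceback__, 7)

innermost = frames[-1]  # the frame in which the exception was raised

# frame.index is the position of the failing line inside code_context, so
# (lineno - index) is the file line number of the first context line.
numbered = [
    ((innermost.lineno - innermost.index) + offset, line.rstrip("\n"))
    for offset, line in enumerate(innermost.code_context or [])
]
```

The entry whose number equals `innermost.lineno` is the one the debugger page renders with `CENTER_LINE`; all others use `LINE`.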
================================================
FILE: starlette/middleware/exceptions.py
================================================
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from starlette._exception_handler import (
ExceptionHandlers,
StatusHandlers,
wrap_app_handling_exceptions,
)
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send
from starlette.websockets import WebSocket
class ExceptionMiddleware:
def __init__(
self,
app: ASGIApp,
handlers: Mapping[Any, ExceptionHandler] | None = None,
debug: bool = False,
) -> None:
self.app = app
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
self._status_handlers: StatusHandlers = {}
self._exception_handlers: ExceptionHandlers = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
}
if handlers is not None: # pragma: no branch
for key, value in handlers.items():
self.add_exception_handler(key, value)
def add_exception_handler(
self,
exc_class_or_status_code: int | type[Exception],
handler: ExceptionHandler,
) -> None:
if isinstance(exc_class_or_status_code, int):
self._status_handlers[exc_class_or_status_code] = handler
else:
assert issubclass(exc_class_or_status_code, Exception)
self._exception_handlers[exc_class_or_status_code] = handler
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return
scope["starlette.exception_handlers"] = (
self._exception_handlers,
self._status_handlers,
)
conn: Request | WebSocket
if scope["type"] == "http":
conn = Request(scope, receive, send)
else:
conn = WebSocket(scope, receive, send)
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
async def http_exception(self, request: Request, exc: Exception) -> Response:
assert isinstance(exc, HTTPException)
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
assert isinstance(exc, WebSocketException)
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
================================================
FILE: starlette/middleware/gzip.py
================================================
import gzip
import io
from typing import NoReturn
from starlette.datastructures import Headers, MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send
DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
class GZipMiddleware:
def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
self.app = app
self.minimum_size = minimum_size
self.compresslevel = compresslevel
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http": # pragma: no cover
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
responder: ASGIApp
if "gzip" in headers.get("Accept-Encoding", ""):
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
else:
responder = IdentityResponder(self.app, self.minimum_size)
await responder(scope, receive, send)
class IdentityResponder:
content_encoding: str
def __init__(self, app: ASGIApp, minimum_size: int) -> None:
self.app = app
self.minimum_size = minimum_size
self.send: Send = unattached_send
self.initial_message: Message = {}
self.started = False
self.content_encoding_set = False
self.content_type_is_excluded = False
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.send = send
await self.app(scope, receive, self.send_with_compression)
async def send_with_compression(self, message: Message) -> None:
message_type = message["type"]
if message_type == "http.response.start":
# Don't send the initial message until we've determined how to
# modify the outgoing headers correctly.
self.initial_message = message
headers = Headers(raw=self.initial_message["headers"])
self.content_encoding_set = "content-encoding" in headers
self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
if not self.started:
self.started = True
await self.send(self.initial_message)
await self.send(message)
elif message_type == "http.response.body" and not self.started:
self.started = True
body = message.get("body", b"")
more_body = message.get("more_body", False)
if len(body) < self.minimum_size and not more_body:
# Don't apply compression to small outgoing responses.
await self.send(self.initial_message)
await self.send(message)
elif not more_body:
# Standard response.
body = self.apply_compression(body, more_body=False)
headers = MutableHeaders(raw=self.initial_message["headers"])
headers.add_vary_header("Accept-Encoding")
if body != message["body"]:
headers["Content-Encoding"] = self.content_encoding
headers["Content-Length"] = str(len(body))
message["body"] = body
await self.send(self.initial_message)
await self.send(message)
else:
# Initial body in streaming response.
body = self.apply_compression(body, more_body=True)
headers = MutableHeaders(raw=self.initial_message["headers"])
headers.add_vary_header("Accept-Encoding")
if body != message["body"]:
headers["Content-Encoding"] = self.content_encoding
del headers["Content-Length"]
message["body"] = body
await self.send(self.initial_message)
await self.send(message)
elif message_type == "http.response.body":
# Remaining body in streaming response.
body = message.get("body", b"")
more_body = message.get("more_body", False)
message["body"] = self.apply_compression(body, more_body=more_body)
await self.send(message)
elif message_type == "http.response.pathsend": # pragma: no branch
# Don't apply GZip to pathsend responses
await self.send(self.initial_message)
await self.send(message)
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
"""Apply compression on the response body.
If more_body is False, any compression file should be closed. If it
isn't, it won't be closed automatically until all background tasks
complete.
"""
return body
class GZipResponder(IdentityResponder):
content_encoding = "gzip"
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
super().__init__(app, minimum_size)
self.gzip_buffer = io.BytesIO()
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
with self.gzip_buffer, self.gzip_file:
await super().__call__(scope, receive, send)
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
self.gzip_file.write(body)
if not more_body:
self.gzip_file.close()
body = self.gzip_buffer.getvalue()
self.gzip_buffer.seek(0)
self.gzip_buffer.truncate()
return body
async def unattached_send(message: Message) -> NoReturn:
raise RuntimeError("send awaitable not set") # pragma: no cover
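
The buffer handling in `GZipResponder.apply_compression()` (write, read back, rewind, truncate) can be tried in isolation. The generator below is a minimal sketch under the assumption that all chunks are known up front; `compress_stream` is a hypothetical name, and only `gzip.GzipFile` and `io.BytesIO` from the standard library are used.

```python
import gzip
import io
from typing import Iterator, Sequence


def compress_stream(chunks: Sequence[bytes], compresslevel: int = 9) -> Iterator[bytes]:
    buffer = io.BytesIO()
    gzip_file = gzip.GzipFile(mode="wb", fileobj=buffer, compresslevel=compresslevel)
    for position, chunk in enumerate(chunks):
        gzip_file.write(chunk)
        if position == len(chunks) - 1:
            # Closing flushes the compressor and writes the gzip trailer.
            # It does not close the caller-supplied buffer.
            gzip_file.close()
        piece = buffer.getvalue()
        # Reset the buffer so the next piece contains only new output.
        buffer.seek(0)
        buffer.truncate()
        yield piece


payload = [b"a" * 1000, b"b" * 1000]
pieces = list(compress_stream(payload))
restored = gzip.decompress(b"".join(pieces))
```

Individual pieces may be small or contain only the gzip header, since the compressor is free to hold data back until it is closed; only the concatenation of all pieces is a valid gzip stream.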
================================================
FILE: starlette/middleware/httpsredirect.py
================================================
from starlette.datastructures import URL
from starlette.responses import RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send
class HTTPSRedirectMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] in ("http", "websocket") and scope["scheme"] in ("http", "ws"):
url = URL(scope=scope)
redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme]
netloc = url.hostname if url.port in (80, 443) else url.netloc
url = url.replace(scheme=redirect_scheme, netloc=netloc)
response = RedirectResponse(url, status_code=307)
await response(scope, receive, send)
else:
await self.app(scope, receive, send)
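
The URL rewrite performed by `HTTPSRedirectMiddleware` can be sketched with `urllib.parse` in place of Starlette's `URL` class. The helper name `https_redirect_target` is hypothetical; the scheme mapping and the port rule (drop the port only when it is 80 or 443) follow the code above.

```python
from urllib.parse import urlsplit, urlunsplit


def https_redirect_target(url: str) -> str:
    parts = urlsplit(url)
    # http -> https, ws -> wss; any other scheme raises KeyError here.
    scheme = {"http": "https", "ws": "wss"}[parts.scheme]
    # Default ports are dropped, non-default ports are carried over unchanged.
    netloc = parts.hostname if parts.port in (80, 443) else parts.netloc
    return urlunsplit((scheme, netloc, parts.path, parts.query, parts.fragment))
```

Note that a non-default port such as 8000 is kept as-is, so the redirect only works if the TLS listener is reachable on that same port.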
================================================
FILE: starlette/middleware/sessions.py
================================================
from __future__ import annotations
import json
import typing
from base64 import b64decode, b64encode
from typing import Literal
import itsdangerous
from itsdangerous.exc import BadSignature
from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Message, Receive, Scope, Send
class SessionMiddleware:
def __init__(
self,
app: ASGIApp,
secret_key: str | Secret,
session_cookie: str = "session",
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
path: str = "/",
same_site: Literal["lax", "strict", "none"] = "lax",
https_only: bool = False,
domain: str | None = None,
) -> None:
self.app = app
self.signer = itsdangerous.TimestampSigner(str(secret_key))
self.session_cookie = session_cookie
self.max_age = max_age
self.path = path
self.security_flags = "httponly; samesite=" + same_site
if https_only: # Secure flag can be used with HTTPS only
self.security_flags += "; secure"
if domain is not None:
self.security_flags += f"; domain={domain}"
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"): # pragma: no cover
await self.app(scope, receive, send)
return
connection = HTTPConnection(scope)
initial_session_was_empty = True
if self.session_cookie in connection.cookies:
data = connection.cookies[self.session_cookie].encode("utf-8")
try:
data = self.signer.unsign(data, max_age=self.max_age)
scope["session"] = Session(json.loads(b64decode(data)))
initial_session_was_empty = False
except BadSignature:
scope["session"] = Session()
else:
scope["session"] = Session()
async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
session: Session = scope["session"]
headers = MutableHeaders(scope=message)
if session.accessed:
headers.add_vary_header("Cookie")
if session.modified and session:
# We have session data to persist.
data = b64encode(json.dumps(session).encode("utf-8"))
data = self.signer.sign(data)
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=self.path,
max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
security_flags=self.security_flags,
)
headers.append("Set-Cookie", header_value)
elif session.modified and not initial_session_was_empty:
# The session has been cleared.
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
session_cookie=self.session_cookie,
data="null",
path=self.path,
expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ",
security_flags=self.security_flags,
)
headers.append("Set-Cookie", header_value)
await send(message)
await self.app(scope, receive, send_wrapper)
class Session(dict[str, typing.Any]):
accessed: bool = False
modified: bool = False
def mark_accessed(self) -> None:
self.accessed = True
def mark_modified(self) -> None:
self.accessed = True
self.modified = True
def __setitem__(self, key: str, value: typing.Any) -> None:
self.mark_modified()
super().__setitem__(key, value)
def __delitem__(self, key: str) -> None:
self.mark_modified()
super().__delitem__(key)
def clear(self) -> None:
self.mark_modified()
super().clear()
def pop(self, key: str, *args: typing.Any) -> typing.Any:
self.modified = self.modified or key in self
return super().pop(key, *args)
def setdefault(self, key: str, default: typing.Any = None) -> typing.Any:
if key not in self:
self.mark_modified()
return super().setdefault(key, default)
def update(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self.mark_modified()
super().update(*args, **kwargs)
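
How the `modified` flag drives the `Set-Cookie` decision in `send_wrapper` can be shown with a reduced version of `Session`. This is a sketch only: `TrackedSession` keeps just three of the overridden methods, and `cookie_action` is a hypothetical helper that returns a label instead of building the header.

```python
from typing import Any, Optional


class TrackedSession(dict):
    modified: bool = False

    def __setitem__(self, key: str, value: Any) -> None:
        self.modified = True
        super().__setitem__(key, value)

    def clear(self) -> None:
        self.modified = True
        super().clear()

    def pop(self, key: str, *args: Any) -> Any:
        # Popping a key that is absent does not count as a modification.
        self.modified = self.modified or key in self
        return super().pop(key, *args)


def cookie_action(session: TrackedSession, initial_session_was_empty: bool) -> Optional[str]:
    if session.modified and session:
        return "persist"  # write a freshly signed cookie
    if session.modified and not initial_session_was_empty:
        return "clear"  # expire the existing cookie
    return None  # leave the response untouched
```

A session that is only read, or one that starts empty and ends empty, therefore produces no `Set-Cookie` header at all.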
================================================
FILE: starlette/middleware/trustedhost.py
================================================
from __future__ import annotations
from collections.abc import Sequence
from starlette.datastructures import URL, Headers
from starlette.responses import PlainTextResponse, RedirectResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
class TrustedHostMiddleware:
def __init__(
self,
app: ASGIApp,
allowed_hosts: Sequence[str] | None = None,
www_redirect: bool = True,
) -> None:
if allowed_hosts is None:
allowed_hosts = ["*"]
for pattern in allowed_hosts:
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
if pattern.startswith("*") and pattern != "*":
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
self.app = app
self.allowed_hosts = list(allowed_hosts)
self.allow_any = "*" in allowed_hosts
self.www_redirect = www_redirect
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.allow_any or scope["type"] not in (
"http",
"websocket",
): # pragma: no cover
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
is_valid_host = False
found_www_redirect = False
for pattern in self.allowed_hosts:
if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
is_valid_host = True
break
elif "www." + host == pattern:
found_www_redirect = True
if is_valid_host:
await self.app(scope, receive, send)
else:
response: Response
if found_www_redirect and self.www_redirect:
url = URL(scope=scope)
redirect_url = url.replace(netloc="www." + url.netloc)
response = RedirectResponse(url=str(redirect_url))
else:
response = PlainTextResponse("Invalid host header", status_code=400)
await response(scope, receive, send)
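
The host-matching loop in `TrustedHostMiddleware.__call__` reduces to a small pure function. The sketch below uses a hypothetical `match_host` helper that returns a label rather than an ASGI response, and assumes `www_redirect` is enabled.

```python
from typing import Sequence


def match_host(host_header: str, allowed_hosts: Sequence[str]) -> str:
    # Strip an optional ":port" suffix, as the middleware does.
    host = host_header.split(":")[0]
    found_www_redirect = False
    for pattern in allowed_hosts:
        # "*.example.com" matches any host ending in ".example.com".
        if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
            return "allow"
        elif "www." + host == pattern:
            found_www_redirect = True
    return "redirect" if found_www_redirect else "reject"
```

One consequence worth noting: the pattern `*.example.com` does not match the bare domain `example.com`, so both entries are needed if both hosts should be served.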
================================================
FILE: starlette/middleware/wsgi.py
================================================
from __future__ import annotations
import io
import math
import sys
import warnings
from collections.abc import Callable, MutableMapping
from typing import Any
import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from starlette.types import Receive, Scope, Send
warnings.warn(
"starlette.middleware.wsgi is deprecated and will be removed in a future release. "
"Please refer to https://github.com/abersheeran/a2wsgi as a replacement.",
DeprecationWarning,
stacklevel=2,
)
def build_environ(scope: Scope, body: bytes) -> dict[str, Any]:
"""
Builds a scope and request body into a WSGI environ object.
"""
script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
path_info = scope["path"].encode("utf8").decode("latin1")
if path_info.startswith(script_name):
path_info = path_info[len(script_name) :]
environ = {
"REQUEST_METHOD": scope["method"],
"SCRIPT_NAME": script_name,
"PATH_INFO": path_info,
"QUERY_STRING": scope["query_string"].decode("ascii"),
"SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
"wsgi.version": (1, 0),
"wsgi.url_scheme": scope.get("scheme", "http"),
"wsgi.input": io.BytesIO(body),
"wsgi.errors": sys.stdout,
"wsgi.multithread": True,
"wsgi.multiprocess": True,
"wsgi.run_once": False,
}
# Get server name and port - required in WSGI, not in ASGI
server = scope.get("server") or ("localhost", 80)
environ["SERVER_NAME"] = server[0]
environ["SERVER_PORT"] = server[1]
# Get client IP address
if scope.get("client"):
environ["REMOTE_ADDR"] = scope["client"][0]
# Go through headers and make them into environ entries
for name, value in scope.get("headers", []):
name = name.decode("latin1")
if name == "content-length":
corrected_name = "CONTENT_LENGTH"
elif name == "content-type":
corrected_name = "CONTENT_TYPE"
else:
corrected_name = f"HTTP_{name}".upper().replace("-", "_")
# HTTPbis says only ASCII characters are allowed in headers, but we decode as
# latin1 just in case.
value = value.decode("latin1")
if corrected_name in environ:
value = environ[corrected_name] + "," + value
environ[corrected_name] = value
return environ
class WSGIMiddleware:
def __init__(self, app: Callable[..., Any]) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] == "http"
responder = WSGIResponder(self.app, scope)
await responder(receive, send)
class WSGIResponder:
stream_send: ObjectSendStream[MutableMapping[str, Any]]
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
def __init__(self, app: Callable[..., Any], scope: Scope) -> None:
self.app = app
self.scope = scope
self.status = None
self.response_headers = None
self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
self.response_started = False
self.exc_info: Any = None
async def __call__(self, receive: Receive, send: Send) -> None:
body = b""
more_body = True
while more_body:
message = await receive()
body += message.get("body", b"")
more_body = message.get("more_body", False)
environ = build_environ(self.scope, body)
async with anyio.create_task_group() as task_group:
task_group.start_soon(self.sender, send)
async with self.stream_send:
await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
if self.exc_info is not None:
raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
async def sender(self, send: Send) -> None:
async with self.stream_receive:
async for message in self.stream_receive:
await send(message)
def start_response(
self,
status: str,
response_headers: list[tuple[str, str]],
exc_info: Any = None,
) -> None:
self.exc_info = exc_info
if not self.response_started: # pragma: no branch
self.response_started = True
status_code_string, _ = status.split(" ", 1)
status_code = int(status_code_string)
headers = [
(name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
for name, value in response_headers
]
anyio.from_thread.run(
self.stream_send.send,
{
"type": "http.response.start",
"status": status_code,
"headers": headers,
},
)
def wsgi(
self,
environ: dict[str, Any],
start_response: Callable[..., Any],
) -> None:
for chunk in self.app(environ, start_response):
anyio.from_thread.run(
self.stream_send.send,
{"type": "http.response.body", "body": chunk, "more_body": True},
)
anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
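# Illustrative sketch, not part of this module: the header handling inside
# `build_environ` can be isolated to show how ASGI header pairs become CGI-style
# environ keys. `headers_to_environ` is a hypothetical name used only for this example.

```python
def headers_to_environ(headers: list[tuple[bytes, bytes]]) -> dict[str, str]:
    """Map ASGI (name, value) byte pairs to WSGI environ entries."""
    environ: dict[str, str] = {}
    for raw_name, raw_value in headers:
        name = raw_name.decode("latin1")
        # Content-Length and Content-Type get no HTTP_ prefix in WSGI.
        if name == "content-length":
            key = "CONTENT_LENGTH"
        elif name == "content-type":
            key = "CONTENT_TYPE"
        else:
            key = f"HTTP_{name}".upper().replace("-", "_")
        value = raw_value.decode("latin1")
        # Repeated headers are folded into one comma-separated value.
        if key in environ:
            value = environ[key] + "," + value
        environ[key] = value
    return environ
```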
================================================
FILE: starlette/py.typed
================================================
================================================
FILE: starlette/requests.py
================================================
from __future__ import annotations
import json
import sys
from collections.abc import AsyncGenerator, Iterator, Mapping
from http import cookies as http_cookies
from typing import TYPE_CHECKING, Any, Generic, NoReturn, cast
import anyio
from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
from starlette.exceptions import HTTPException
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
from starlette.types import Message, Receive, Scope, Send
if TYPE_CHECKING:
from python_multipart.multipart import parse_options_header
from starlette.applications import Starlette
from starlette.middleware.sessions import Session
from starlette.routing import Router
else:
try:
try:
from python_multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
from multipart.multipart import parse_options_header
except ModuleNotFoundError: # pragma: no cover
parse_options_header = None
if sys.version_info >= (3, 13): # pragma: no cover
from typing import TypeVar
else: # pragma: no cover
from typing_extensions import TypeVar
SERVER_PUSH_HEADERS_TO_COPY = {
"accept",
"accept-encoding",
"accept-language",
"cache-control",
"user-agent",
}
def cookie_parser(cookie_string: str) -> dict[str, str]:
"""
This function parses a ``Cookie`` HTTP header into a dict of key/value pairs.
It attempts to mimic browser cookie parsing behavior: browsers and web servers
frequently disregard the spec (RFC 6265) when setting and reading cookies,
so we attempt to suit the common scenarios here.
This function has been adapted from Django 3.1.0.
Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based
on an outdated spec and fails on much of the input we want to support.
"""
cookie_dict: dict[str, str] = {}
for chunk in cookie_string.split(";"):
if "=" in chunk:
key, val = chunk.split("=", 1)
else:
# Assume an empty name per
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
key, val = "", chunk
key, val = key.strip(), val.strip()
if key or val:
# unquote using Python's algorithm.
cookie_dict[key] = http_cookies._unquote(val)
return cookie_dict
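# Illustrative sketch, not part of this module: a standalone copy of the parsing loop
# above, used to show how lenient inputs are handled. `parse_cookie_header` is a
# hypothetical name; like the original, it relies on the private stdlib helper
# `http.cookies._unquote`, which may change between Python versions.

```python
from http import cookies as http_cookies


def parse_cookie_header(cookie_string: str) -> dict[str, str]:
    """Parse a Cookie header leniently, roughly the way browsers do."""
    result: dict[str, str] = {}
    for chunk in cookie_string.split(";"):
        if "=" in chunk:
            # Split on the first "=" only, so values may contain "=".
            key, val = chunk.split("=", 1)
        else:
            # A bare token is stored under the empty name.
            key, val = "", chunk
        key, val = key.strip(), val.strip()
        if key or val:  # skip empty chunks such as a trailing ";"
            result[key] = http_cookies._unquote(val)
    return result
```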
class ClientDisconnect(Exception):
pass
StateT = TypeVar("StateT", bound=Mapping[str, Any] | State, default=State)
class HTTPConnection(Mapping[str, Any], Generic[StateT]):
"""
A base class for incoming HTTP connections, that is used to provide
any functionality that is common to both `Request` and `WebSocket`.
"""
def __init__(self, scope: Scope, receive: Receive | None = None) -> None:
assert scope["type"] in ("http", "websocket")
self.scope = scope
def __getitem__(self, key: str) -> Any:
return self.scope[key]
def __iter__(self) -> Iterator[str]:
return iter(self.scope)
def __len__(self) -> int:
return len(self.scope)
# Don't use the `abc.Mapping.__eq__` implementation.
# Connection instances should never be considered equal
# unless `self is other`.
__eq__ = object.__eq__
__hash__ = object.__hash__
@property
def app(self) -> Any:
return self.scope["app"]
@property
def url(self) -> URL:
if not hasattr(self, "_url"): # pragma: no branch
self._url = URL(scope=self.scope)
return self._url
@property
def base_url(self) -> URL:
if not hasattr(self, "_base_url"):
base_url_scope = dict(self.scope)
# This is used by request.url_for, it might be used inside a Mount which
# would have its own child scope with its own root_path, but the base URL
# for url_for should still be the top level app root path.
app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
path = app_root_path
if not path.endswith("/"):
path += "/"
base_url_scope["path"] = path
base_url_scope["query_string"] = b""
base_url_scope["root_path"] = app_root_path
self._base_url = URL(scope=base_url_scope)
return self._base_url
@property
def headers(self) -> Headers:
if not hasattr(self, "_headers"):
self._headers = Headers(scope=self.scope)
return self._headers
@property
def query_params(self) -> QueryParams:
if not hasattr(self, "_query_params"): # pragma: no branch
self._query_params = QueryParams(self.scope["query_string"])
return self._query_params
@property
def path_params(self) -> dict[str, Any]:
return self.scope.get("path_params", {})
@property
def cookies(self) -> dict[str, str]:
if not hasattr(self, "_cookies"):
cookies: dict[str, str] = {}
cookie_headers = self.headers.getlist("cookie")
for header in cookie_headers:
cookies.update(cookie_parser(header))
self._cookies = cookies
return self._cookies
@property
def client(self) -> Address | None:
# client is a 2 item tuple of (host, port), None if missing
host_port = self.scope.get("client")
if host_port is not None:
return Address(*host_port)
return None
@property
def session(self) -> dict[str, Any]:
assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
session: Session = self.scope["session"]
# We keep the hasattr in case people actually use their own `SessionMiddleware` implementation.
if hasattr(session, "mark_accessed"): # pragma: no branch
session.mark_accessed()
return session
@property
def auth(self) -> Any:
assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
return self.scope["auth"]
@property
def user(self) -> Any:
assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
return self.scope["user"]
@property
def state(self) -> StateT:
if not hasattr(self, "_state"):
# Ensure 'state' has an empty dict if it's not already populated.
self.scope.setdefault("state", {})
# Create a state instance with a reference to the dict in which it should
# store info
self._state = State(self.scope["state"])
return cast(StateT, self._state)
def url_for(self, name: str, /, **path_params: Any) -> URL:
url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app")
if url_path_provider is None:
raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.")
url_path = url_path_provider.url_path_for(name, **path_params)
return url_path.make_absolute_url(base_url=self.base_url)
async def empty_receive() -> NoReturn:
raise RuntimeError("Receive channel has not been made available")
async def empty_send(message: Message) -> NoReturn:
raise RuntimeError("Send channel has not been made available")
class Request(HTTPConnection[StateT]):
_form: FormData | None
def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
super().__init__(scope)
assert scope["type"] == "http"
self._receive = receive
self._send = send
self._stream_consumed = False
self._is_disconnected = False
self._form = None
@property
def method(self) -> str:
return cast(str, self.scope["method"])
@property
def receive(self) -> Receive:
return self._receive
async def stream(self) -> AsyncGenerator[bytes, None]:
if hasattr(self, "_body"):
yield self._body
yield b""
return
if self._stream_consumed:
raise RuntimeError("Stream consumed")
while not self._stream_consumed:
message = await self._receive()
if message["type"] == "http.request":
body = message.get("body", b"")
if not message.get("more_body", False):
self._stream_consumed = True
if body:
yield body
elif message["type"] == "http.disconnect": # pragma: no branch
self._is_disconnected = True
raise ClientDisconnect()
yield b""
async def body(self) -> bytes:
if not hasattr(self, "_body"):
chunks: list[bytes] = []
async for chunk in self.stream():
chunks.append(chunk)
self._body = b"".join(chunks)
return self._body
async def json(self) -> Any:
if not hasattr(self, "_json"): # pragma: no branch
body = await self.body()
self._json = json.loads(body)
return self._json
async def _get_form(
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> FormData:
if self._form is None: # pragma: no branch
assert parse_options_header is not None, (
"The `python-multipart` library must be installed to use form parsing."
)
content_type_header = self.headers.get("Content-Type")
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
if content_type == b"multipart/form-data":
try:
multipart_parser = MultiPartParser(
self.headers,
self.stream(),
max_files=max_files,
max_fields=max_fields,
max_part_size=max_part_size,
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
if "app" in self.scope:
raise HTTPException(status_code=400, detail=exc.message)
raise exc
elif content_type == b"application/x-www-form-urlencoded":
form_parser = FormParser(self.headers, self.stream())
self._form = await form_parser.parse()
else:
self._form = FormData()
return self._form
def form(
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_part_size: int = 1024 * 1024,
) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(
self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
)
async def close(self) -> None:
if self._form is not None: # pragma: no branch
await self._form.close()
async def is_disconnected(self) -> bool:
if not self._is_disconnected:
message: Message = {}
# If message isn't immediately available, move on
with anyio.CancelScope() as cs:
cs.cancel()
message = await self._receive()
if message.get("type") == "http.disconnect":
self._is_disconnected = True
return self._is_disconnected
async def send_push_promise(self, path: str) -> None:
if "http.response.push" in self.scope.get("extensions", {}):
raw_headers: list[tuple[bytes, bytes]] = []
for name in SERVER_PUSH_HEADERS_TO_COPY:
for value in self.headers.getlist(name):
raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})
================================================
FILE: starlette/responses.py
================================================
from __future__ import annotations
import hashlib
import http.cookies
import json
import os
import stat
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, Iterable, Mapping, Sequence
from datetime import datetime
from email.utils import format_datetime, formatdate
from functools import partial
from mimetypes import guess_type
from secrets import token_hex
from typing import Any, Literal
from urllib.parse import quote
import anyio
import anyio.to_thread
from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders
from starlette.requests import ClientDisconnect
from starlette.types import Message, Receive, Scope, Send
class Response:
media_type = None
charset = "utf-8"
def __init__(
self,
content: Any = None,
status_code: int = 200,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
self.status_code = status_code
if media_type is not None:
self.media_type = media_type
self.background = background
self.body = self.render(content)
self.init_headers(headers)
def render(self, content: Any) -> bytes | memoryview:
if content is None:
return b""
if isinstance(content, bytes | memoryview):
return content
return content.encode(self.charset) # type: ignore
def init_headers(self, headers: Mapping[str, str] | None = None) -> None:
if headers is None:
raw_headers: list[tuple[bytes, bytes]] = []
populate_content_length = True
populate_content_type = True
else:
raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
keys = [h[0] for h in raw_headers]
populate_content_length = b"content-length" not in keys
populate_content_type = b"content-type" not in keys
body = getattr(self, "body", None)
if (
body is not None
and populate_content_length
and not (self.status_code < 200 or self.status_code in (204, 304))
):
content_length = str(len(body))
raw_headers.append((b"content-length", content_length.encode("latin-1")))
content_type = self.media_type
if content_type is not None and populate_content_type:
if content_type.startswith("text/") and "charset=" not in content_type.lower():
content_type += "; charset=" + self.charset
raw_headers.append((b"content-type", content_type.encode("latin-1")))
self.raw_headers = raw_headers
@property
def headers(self) -> MutableHeaders:
if not hasattr(self, "_headers"):
self._headers = MutableHeaders(raw=self.raw_headers)
return self._headers
def set_cookie(
self,
key: str,
value: str = "",
max_age: int | None = None,
expires: datetime | str | int | None = None,
path: str | None = "/",
domain: str | None = None,
secure: bool = False,
httponly: bool = False,
samesite: Literal["lax", "strict", "none"] | None = "lax",
partitioned: bool = False,
) -> None:
cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie()
cookie[key] = value
if max_age is not None:
cookie[key]["max-age"] = max_age
if expires is not None:
if isinstance(expires, datetime):
cookie[key]["expires"] = format_datetime(expires, usegmt=True)
else:
cookie[key]["expires"] = expires
if path is not None:
cookie[key]["path"] = path
if domain is not None:
cookie[key]["domain"] = domain
if secure:
cookie[key]["secure"] = True
if httponly:
cookie[key]["httponly"] = True
if samesite is not None:
assert samesite.lower() in [
"strict",
"lax",
"none",
], "samesite must be either 'strict', 'lax' or 'none'"
cookie[key]["samesite"] = samesite
if partitioned:
if sys.version_info < (3, 14):
raise ValueError("Partitioned cookies are only supported in Python 3.14 and above.") # pragma: no cover
cookie[key]["partitioned"] = True # pragma: no cover
cookie_val = cookie.output(header="").strip()
self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))
def delete_cookie(
self,
key: str,
path: str = "/",
domain: str | None = None,
secure: bool = False,
httponly: bool = False,
samesite: Literal["lax", "strict", "none"] | None = "lax",
) -> None:
self.set_cookie(
key,
max_age=0,
expires=0,
path=path,
domain=domain,
secure=secure,
httponly=httponly,
samesite=samesite,
)
def _wrap_websocket_denial_send(self, send: Send) -> Send:
async def wrapped(message: Message) -> None:
message_type = message["type"]
if message_type in {"http.response.start", "http.response.body"}: # pragma: no branch
message = {**message, "type": "websocket." + message_type}
await send(message)
return wrapped
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
send = self._wrap_websocket_denial_send(send)
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
await send({"type": "http.response.body", "body": self.body})
if self.background is not None:
await self.background()
class HTMLResponse(Response):
media_type = "text/html"
class PlainTextResponse(Response):
media_type = "text/plain"
class JSONResponse(Response):
media_type = "application/json"
def __init__(
self,
content: Any,
status_code: int = 200,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
super().__init__(content, status_code, headers, media_type, background)
def render(self, content: Any) -> bytes:
return json.dumps(
content,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
).encode("utf-8")
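# Illustrative sketch, not part of this module: the `json.dumps` options used by
# `JSONResponse.render` above, applied directly. `render_json` is a hypothetical name.

```python
import json
from typing import Any


def render_json(content: Any) -> bytes:
    """Serialize content the same way the render method above does."""
    return json.dumps(
        content,
        ensure_ascii=False,     # keep non-ASCII characters as UTF-8, not \u escapes
        allow_nan=False,        # NaN and Infinity are not valid JSON, so reject them
        indent=None,
        separators=(",", ":"),  # most compact form, no spaces after separators
    ).encode("utf-8")
```

# With `allow_nan=False`, a value such as float("nan") raises ValueError instead of
# producing output that strict JSON parsers would reject.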
class RedirectResponse(Response):
def __init__(
self,
url: str | URL,
status_code: int = 307,
headers: Mapping[str, str] | None = None,
background: BackgroundTask | None = None,
) -> None:
super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
Content = str | bytes | memoryview
SyncContentStream = Iterable[Content]
AsyncContentStream = AsyncIterable[Content]
ContentStream = AsyncContentStream | SyncContentStream
class StreamingResponse(Response):
body_iterator: AsyncContentStream
def __init__(
self,
content: ContentStream,
status_code: int = 200,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
if isinstance(content, AsyncIterable):
self.body_iterator = content
else:
self.body_iterator = iterate_in_threadpool(content)
self.status_code = status_code
self.media_type = self.media_type if media_type is None else media_type
self.background = background
self.init_headers(headers)
async def listen_for_disconnect(self, receive: Receive) -> None:
while True:
message = await receive()
if message["type"] == "http.disconnect":
break
async def stream_response(self, send: Send) -> None:
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
async for chunk in self.body_iterator:
if not isinstance(chunk, bytes | memoryview):
chunk = chunk.encode(self.charset)
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"", "more_body": False})
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
send = self._wrap_websocket_denial_send(send)
await self.stream_response(send)
if self.background is not None:
await self.background()
return
spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))
if spec_version >= (2, 4):
try:
await self.stream_response(send)
except OSError:
raise ClientDisconnect()
else:
with collapse_excgroups():
async with anyio.create_task_group() as task_group:
async def wrap(func: Callable[[], Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
if self.background is not None:
await self.background()
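# Illustrative sketch, not part of this module: `StreamingResponse.__call__` above
# compares the ASGI spec version as a tuple of integers. `parse_spec_version` is a
# hypothetical helper that isolates that expression.

```python
def parse_spec_version(scope: dict) -> tuple[int, ...]:
    """Read scope["asgi"]["spec_version"] as a tuple, defaulting to (2, 0)."""
    raw = scope.get("asgi", {}).get("spec_version", "2.0")
    return tuple(map(int, raw.split(".")))
```

# Comparing tuples rather than strings matters once a component reaches two digits:
# as strings, "2.10" sorts before "2.4", while (2, 10) >= (2, 4) holds as intended.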
class MalformedRangeHeader(Exception):
def __init__(self, content: str = "Malformed range header.") -> None:
self.content = content
class RangeNotSatisfiable(Exception):
def __init__(self, max_size: int) -> None:
self.max_size = max_size
class FileResponse(Response):
chunk_size = 64 * 1024
def __init__(
self,
path: str | os.PathLike[str],
status_code: int = 200,
headers: Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
filename: str | None = None,
stat_result: os.stat_result | None = None,
content_disposition_type: str = "attachment",
) -> None:
self.path = path
self.status_code = status_code
self.filename = filename
if media_type is None:
media_type = guess_type(filename or path)[0] or "text/plain"
self.media_type = media_type
self.background = background
self.init_headers(headers)
self.headers.setdefault("accept-ranges", "bytes")
if self.filename is not None:
content_disposition_filename = quote(self.filename)
if content_disposition_filename != self.filename:
content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
else:
content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
self.headers.setdefault("content-disposition", content_disposition)
self.stat_result = stat_result
if stat_result is not None:
self.set_stat_headers(stat_result)
def set_stat_headers(self, stat_result: os.stat_result) -> None:
content_length = str(stat_result.st_size)
last_modified = formatdate(stat_result.st_mtime, usegmt=True)
etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"'
self.headers.setdefault("content-length", content_length)
self.headers.setdefault("last-modified", last_modified)
self.headers.setdefault("etag", etag)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope_type = scope["type"]
send_header_only = scope_type == "http" and scope["method"].upper() == "HEAD"
send_pathsend = scope_type == "http" and "http.response.pathsend" in scope.get("extensions", {})
if scope_type == "websocket":
send = self._wrap_websocket_denial_send(send)
if self.stat_result is None:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
self.set_stat_headers(stat_result)
except FileNotFoundError:
raise RuntimeError(f"File at path {self.path} does not exist.")
else:
mode = stat_result.st_mode
if not stat.S_ISREG(mode):
raise RuntimeError(f"File at path {self.path} is not a file.")
else:
stat_result = self.stat_result
headers = Headers(scope=scope)
http_range = headers.get("range")
http_if_range = headers.get("if-range")
if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
await self._handle_simple(send, send_header_only, send_pathsend)
else:
try:
ranges = self._parse_range_header(http_range, stat_result.st_size)
except MalformedRangeHeader as exc:
return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send)
except RangeNotSatisfiable as exc:
response = PlainTextResponse(status_code=416, headers={"Content-Range": f"bytes */{exc.max_size}"})
return await response(scope, receive, send)
if len(ranges) == 1:
start, end = ranges[0]
await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only)
else:
await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only)
if self.background is not None:
await self.background()
async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None:
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
elif send_pathsend:
await send({"type": "http.response.pathsend", "path": str(self.path)})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
more_body = True
while more_body:
chunk = await file.read(self.chunk_size)
more_body = len(chunk) == self.chunk_size
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
async def _handle_single_range(
self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
) -> None:
headers = MutableHeaders(raw=list(self.raw_headers))
headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}"
headers["content-length"] = str(end - start)
await send({"type": "http.response.start", "status": 206, "headers": headers.raw})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
await file.seek(start)
more_body = True
while more_body:
chunk = await file.read(min(self.chunk_size, end - start))
start += len(chunk)
more_body = len(chunk) == self.chunk_size and start < end
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})
async def _handle_multiple_ranges(
self,
send: Send,
ranges: list[tuple[int, int]],
file_size: int,
send_header_only: bool,
) -> None:
# Firefox and Chrome use boundaries with 95-96 bits of entropy (roughly 13 bytes).
boundary = token_hex(13)
content_length, header_generator = self.generate_multipart(
ranges, boundary, file_size, self.headers["content-type"]
)
headers = MutableHeaders(raw=list(self.raw_headers))
headers["content-type"] = f"multipart/byteranges; boundary={boundary}"
headers["content-length"] = str(content_length)
await send({"type": "http.response.start", "status": 206, "headers": headers.raw})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
for start, end in ranges:
await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
await file.seek(start)
while start < end:
chunk = await file.read(min(self.chunk_size, end - start))
start += len(chunk)
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"\r\n", "more_body": True})
await send(
{
"type": "http.response.body",
"body": f"--{boundary}--".encode("latin-1"),
"more_body": False,
}
)
def _should_use_range(self, http_if_range: str) -> bool:
return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"]
@classmethod
def _parse_range_header(cls, http_range: str, file_size: int) -> list[tuple[int, int]]:
ranges: list[tuple[int, int]] = []
try:
units, range_ = http_range.split("=", 1)
except ValueError:
raise MalformedRangeHeader()
units = units.strip().lower()
if units != "bytes":
raise MalformedRangeHeader("Only support bytes range")
ranges = cls._parse_ranges(range_, file_size)
if len(ranges) == 0:
raise MalformedRangeHeader("Range header: range must be requested")
if any(not (0 <= start < file_size) for start, _ in ranges):
raise RangeNotSatisfiable(file_size)
if any(start > end for start, end in ranges):
raise MalformedRangeHeader("Range header: start must be less than end")
if len(ranges) == 1:
return ranges
# Merge overlapping ranges
ranges.sort()
result: list[tuple[int, int]] = [ranges[0]]
for start, end in ranges[1:]:
last_start, last_end = result[-1]
if start <= last_end:
result[-1] = (last_start, max(last_end, end))
else:
result.append((start, end))
return result
@classmethod
def _parse_ranges(cls, range_: str, file_size: int) -> list[tuple[int, int]]:
ranges: list[tuple[int, int]] = []
for part in range_.split(","):
part = part.strip()
# If the range is empty or a single dash, we ignore it.
if not part or part == "-":
continue
# If the range is not in the format "start-end", we ignore it.
if "-" not in part:
continue
start_str, end_str = part.split("-", 1)
start_str = start_str.strip()
end_str = end_str.strip()
try:
start = int(start_str) if start_str else file_size - int(end_str)
end = int(end_str) + 1 if start_str and end_str and int(end_str) < file_size else file_size
ranges.append((start, end))
except ValueError:
# If the range is not numeric, we ignore it.
continue
return ranges
def generate_multipart(
self,
ranges: Sequence[tuple[int, int]],
boundary: str,
max_size: int,
content_type: str,
) -> tuple[int, Callable[[int, int], bytes]]:
r"""
Multipart response headers generator.
```
--{boundary}\r\n
Content-Type: {content_type}\r\n
Content-Range: bytes {start}-{end-1}/{max_size}\r\n
\r\n
..........content...........\r\n
--{boundary}\r\n
Content-Type: {content_type}\r\n
Content-Range: bytes {start}-{end-1}/{max_size}\r\n
\r\n
..........content...........\r\n
--{boundary}--
```
"""
boundary_len = len(boundary)
static_header_part_len = 49 + boundary_len + len(content_type) + len(str(max_size))
content_length = sum(
(len(str(start)) + len(str(end - 1)) + static_header_part_len) # Headers
+ (end - start) # Content
for start, end in ranges
) + (
4 + boundary_len # --boundary--
)
return (
content_length,
lambda start, end: (
f"""\
--{boundary}\r
Content-Type: {content_type}\r
Content-Range: bytes {start}-{end - 1}/{max_size}\r
\r
"""
).encode("latin-1"),
)
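# Illustrative sketch, not part of this module: where the constant 49 in
# `generate_multipart` comes from. The helpers below are hypothetical names that
# rebuild a multipart/byteranges body and compare its length with the formula above.

```python
def part_header(boundary: str, content_type: str, start: int, end: int, max_size: int) -> bytes:
    # Same layout as in the docstring above: three header lines, then a blank line.
    return (
        f"--{boundary}\r\n"
        f"Content-Type: {content_type}\r\n"
        f"Content-Range: bytes {start}-{end - 1}/{max_size}\r\n"
        "\r\n"
    ).encode("latin-1")


def predicted_length(ranges: list[tuple[int, int]], boundary: str, content_type: str, max_size: int) -> int:
    # Fixed bytes per part, excluding the variable-width fields:
    #   "--" + CRLF                                   =  4
    #   "Content-Type: " + CRLF                       = 16
    #   "Content-Range: bytes " + "-" + "/" + CRLF    = 25
    #   blank line (CRLF)                             =  2
    #   CRLF after the part content                   =  2
    #                                           total = 49
    static = 49 + len(boundary) + len(content_type) + len(str(max_size))
    per_part = sum(len(str(s)) + len(str(e - 1)) + static + (e - s) for s, e in ranges)
    return per_part + 4 + len(boundary)  # closing "--boundary--"


def build_body(data: bytes, ranges: list[tuple[int, int]], boundary: str, content_type: str) -> bytes:
    body = b""
    for start, end in ranges:  # end is exclusive, as in the code above
        body += part_header(boundary, content_type, start, end, len(data))
        body += data[start:end] + b"\r\n"
    return body + f"--{boundary}--".encode("latin-1")
```

# Because the Content-Length is sent before any part is read from disk, the formula
# has to predict the exact byte count of the body that is streamed afterwards.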
================================================
FILE: starlette/routing.py
================================================
from __future__ import annotations
import contextlib
import functools
import inspect
import re
import traceback
import types
import warnings
from collections.abc import Awaitable, Callable, Collection, Generator, Sequence
from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
from enum import Enum
from re import Pattern
from typing import Any, TypeVar
from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import get_route_path, is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse, Response
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose
class NoMatchFound(Exception):
"""
Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
if no matching route exists.
"""
def __init__(self, name: str, path_params: dict[str, Any]) -> None:
params = ", ".join(list(path_params.keys()))
super().__init__(f'No route exists for name "{name}" and params "{params}".')
class Match(Enum):
NONE = 0
PARTIAL = 1
FULL = 2
def request_response(
func: Callable[[Request], Awaitable[Response] | Response],
) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
f: Callable[[Request], Awaitable[Response]] = (
func if is_async_callable(func) else functools.partial(run_in_threadpool, func)
)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive, send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = await f(request)
await response(scope, receive, send)
await wrap_app_handling_exceptions(app, request)(scope, receive, send)
return app
def websocket_session(
func: Callable[[WebSocket], Awaitable[None]],
) -> ASGIApp:
"""
Takes a coroutine `func(session)`, and returns an ASGI application.
"""
# assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"
async def app(scope: Scope, receive: Receive, send: Send) -> None:
session = WebSocket(scope, receive=receive, send=send)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
await func(session)
await wrap_app_handling_exceptions(app, session)(scope, receive, send)
return app
def get_name(endpoint: Callable[..., Any]) -> str:
return getattr(endpoint, "__name__", endpoint.__class__.__name__)
def replace_params(
path: str,
param_convertors: dict[str, Convertor[Any]],
path_params: dict[str, str],
) -> tuple[str, dict[str, str]]:
for key, value in list(path_params.items()):
if "{" + key + "}" in path:
convertor = param_convertors[key]
value = convertor.to_string(value)
path = path.replace("{" + key + "}", value)
path_params.pop(key)
return path, path_params
# Match parameters in URL paths, eg. '{param}', and '{param:int}'
PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
def compile_path(
path: str,
) -> tuple[Pattern[str], str, dict[str, Convertor[Any]]]:
"""
Given a path string, like: "/{username:str}",
or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
of (regex, format, {param_name:convertor}).
regex: "/(?P<username>[^/]+)"
format: "/{username}"
convertors: {"username": StringConvertor()}
"""
is_host = not path.startswith("/")
path_regex = "^"
path_format = ""
duplicated_params: set[str] = set()
idx = 0
param_convertors = {}
for match in PARAM_REGEX.finditer(path):
param_name, convertor_type = match.groups("str")
convertor_type = convertor_type.lstrip(":")
assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
convertor = CONVERTOR_TYPES[convertor_type]
path_regex += re.escape(path[idx : match.start()])
path_regex += f"(?P<{param_name}>{convertor.regex})"
path_format += path[idx : match.start()]
path_format += "{%s}" % param_name
if param_name in param_convertors:
duplicated_params.add(param_name)
param_convertors[param_name] = convertor
idx = match.end()
if duplicated_params:
names = ", ".join(sorted(duplicated_params))
ending = "s" if len(duplicated_params) > 1 else ""
raise ValueError(f"Duplicated param name{ending} {names} at path {path}")
if is_host:
# Align with `Host.matches()` behavior, which ignores port.
hostname = path[idx:].split(":")[0]
path_regex += re.escape(hostname) + "$"
else:
path_regex += re.escape(path[idx:]) + "$"
path_format += path[idx:]
return re.compile(path_regex), path_format, param_convertors
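# Illustrative sketch (not part of the module): a stripped-down version of
# the translation `compile_path` performs, from a path template to a regex
# and a format string. `sketch_compile_path` and `SKETCH_CONVERTORS` are
# hypothetical names; only two convertor types are modelled, with regexes
# assumed for this example, and host handling and duplicate-name checks
# are left out.

```python
from __future__ import annotations

import re

# Assumed convertor regexes for this sketch only.
SKETCH_CONVERTORS = {"str": "[^/]+", "int": "[0-9]+"}
SKETCH_PARAM = re.compile(r"{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")


def sketch_compile_path(path: str) -> tuple[re.Pattern[str], str]:
    regex, fmt, idx = "^", "", 0
    for match in SKETCH_PARAM.finditer(path):
        # A parameter without an explicit type falls back to "str".
        name, kind = match.groups("str")
        kind = kind.lstrip(":")
        literal = path[idx : match.start()]
        regex += re.escape(literal) + f"(?P<{name}>{SKETCH_CONVERTORS[kind]})"
        fmt += literal + "{" + name + "}"
        idx = match.end()
    # Whatever follows the last parameter is matched literally.
    regex += re.escape(path[idx:]) + "$"
    fmt += path[idx:]
    return re.compile(regex), fmt


pattern, path_format = sketch_compile_path("/users/{user_id:int}/posts/{slug}")
hit = pattern.match("/users/42/posts/hello")
miss = pattern.match("/users/abc/posts/hello")
```

# The regex is used for matching incoming paths, while the format string
# (with the convertor suffixes removed) is used for reversing URLs.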
class BaseRoute:
def matches(self, scope: Scope) -> tuple[Match, Scope]:
raise NotImplementedError() # pragma: no cover
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
raise NotImplementedError() # pragma: no cover
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
raise NotImplementedError() # pragma: no cover
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
A route may be used in isolation as a stand-alone ASGI app.
This is a somewhat contrived case, as they'll almost always be used
within a Router, but could be useful for some tooling and minimal apps.
"""
match, child_scope = self.matches(scope)
if match == Match.NONE:
if scope["type"] == "http":
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)
elif scope["type"] == "websocket": # pragma: no branch
websocket_close = WebSocketClose()
await websocket_close(scope, receive, send)
return
scope.update(child_scope)
await self.handle(scope, receive, send)
class Route(BaseRoute):
def __init__(
self,
path: str,
endpoint: Callable[..., Any],
*,
methods: Collection[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
middleware: Sequence[Middleware] | None = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.include_in_schema = include_in_schema
endpoint_handler = endpoint
while isinstance(endpoint_handler, functools.partial):
endpoint_handler = endpoint_handler.func
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
# Endpoint is function or method. Treat it as `func(request) -> response`.
self.app = request_response(endpoint)
if methods is None:
methods = ["GET"]
else:
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(self.app, *args, **kwargs)
if methods is None:
self.methods = None
else:
self.methods = {method.upper() for method in methods}
if "GET" in self.methods:
self.methods.add("HEAD")
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: dict[str, Any]
if scope["type"] == "http":
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"endpoint": self.endpoint, "path_params": path_params}
if self.methods and scope["method"] not in self.methods:
return Match.PARTIAL, child_scope
else:
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
seen_params = set(path_params.keys())
expected_params = set(self.param_convertors.keys())
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="http")
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.methods and scope["method"] not in self.methods:
headers = {"Allow": ", ".join(self.methods)}
if "app" in scope:
raise HTTPException(status_code=405, headers=headers)
else:
response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
await response(scope, receive, send)
else:
await self.app(scope, receive, send)
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, Route)
and self.path == other.path
and self.endpoint == other.endpoint
and self.methods == other.methods
)
def __repr__(self) -> str:
class_name = self.__class__.__name__
methods = sorted(self.methods or [])
path, name = self.path, self.name
return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})"
class WebSocketRoute(BaseRoute):
def __init__(
self,
path: str,
endpoint: Callable[..., Any],
*,
name: str | None = None,
middleware: Sequence[Middleware] | None = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
endpoint_handler = endpoint
while isinstance(endpoint_handler, functools.partial):
endpoint_handler = endpoint_handler.func
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
# Endpoint is function or method. Treat it as `func(websocket)`.
self.app = websocket_session(endpoint)
else:
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(self.app, *args, **kwargs)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: dict[str, Any]
if scope["type"] == "websocket":
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"endpoint": self.endpoint, "path_params": path_params}
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
seen_params = set(path_params.keys())
expected_params = set(self.param_convertors.keys())
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="websocket")
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: Any) -> bool:
return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
class Mount(BaseRoute):
def __init__(
self,
path: str,
app: ASGIApp | None = None,
routes: Sequence[BaseRoute] | None = None,
name: str | None = None,
*,
middleware: Sequence[Middleware] | None = None,
) -> None:
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
self.path = path.rstrip("/")
if app is not None:
self._base_app: ASGIApp = app
else:
self._base_app = Router(routes=routes)
self.app = self._base_app
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
self.app = cls(self.app, *args, **kwargs)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
@property
def routes(self) -> list[BaseRoute]:
return getattr(self._base_app, "routes", [])
def matches(self, scope: Scope) -> tuple[Match, Scope]:
path_params: dict[str, Any]
if scope["type"] in ("http", "websocket"): # pragma: no branch
root_path = scope.get("root_path", "")
route_path = get_route_path(scope)
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = route_path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {
"path_params": path_params,
# app_root_path will only be set at the top level scope,
# initialized with the (optional) value of a root_path
# set above/before Starlette. And even though any
# mount will have its own child scope with its own respective
# root_path, the app_root_path will always be available in all
# the child scopes with the same top level value because it's
# set only once here with a default, any other child scope will
# just inherit that app_root_path default value stored in the
# scope. All this is needed to support Request.url_for(), as it
# uses the app_root_path to build the URL path.
"app_root_path": scope.get("app_root_path", root_path),
"root_path": root_path + matched_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path_params["path"] = path_params["path"].lstrip("/")
path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path)
elif self.name is None or name.startswith(self.name + ":"):
if self.name is None:
# No mount name.
remaining_name = name
else:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
path_kwarg = path_params.get("path")
path_params["path"] = ""
path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if path_kwarg is not None:
remaining_params["path"] = path_kwarg
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
except NoMatchFound:
pass
raise NoMatchFound(name, path_params)
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: Any) -> bool:
return isinstance(other, Mount) and self.path == other.path and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
name = self.name or ""
return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})"
class Host(BaseRoute):
def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None:
assert not host.startswith("/"), "Host must not start with '/'"
self.host = host
self.app = app
self.name = name
self.host_regex, self.host_format, self.param_convertors = compile_path(host)
@property
def routes(self) -> list[BaseRoute]:
return getattr(self.app, "routes", [])
def matches(self, scope: Scope) -> tuple[Match, Scope]:
if scope["type"] in ("http", "websocket"): # pragma:no branch
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"path_params": path_params, "endpoint": self.app}
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path = path_params.pop("path")
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path, host=host)
elif self.name is None or name.startswith(self.name + ":"):
if self.name is None:
# No mount name.
remaining_name = name
else:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
return URLPath(path=str(url), protocol=url.protocol, host=host)
except NoMatchFound:
pass
raise NoMatchFound(name, path_params)
async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def __eq__(self, other: Any) -> bool:
return isinstance(other, Host) and self.host == other.host and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
name = self.name or ""
return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"
_T = TypeVar("_T")
class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]):
def __init__(self, cm: AbstractContextManager[_T]):
self._cm = cm
async def __aenter__(self) -> _T:
return self._cm.__enter__()
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> bool | None:
return self._cm.__exit__(exc_type, exc_value, traceback)
def _wrap_gen_lifespan_context(
lifespan_context: Callable[[Any], Generator[Any, Any, Any]],
) -> Callable[[Any], AbstractAsyncContextManager[Any]]:
cmgr = contextlib.contextmanager(lifespan_context)
@functools.wraps(cmgr)
def wrapper(app: Any) -> _AsyncLiftContextManager[Any]:
return _AsyncLiftContextManager(cmgr(app))
return wrapper
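# Illustrative sketch (not part of the module): how a plain generator
# function can be driven through `async with`, which is the idea behind
# `_AsyncLiftContextManager` and `_wrap_gen_lifespan_context` above.
# `AsyncLift`, `sync_lifespan` and `events` are hypothetical names used
# only in this example.

```python
from __future__ import annotations

import asyncio
import contextlib
from collections.abc import Generator
from typing import Any

events: list[str] = []


def sync_lifespan(app: Any) -> Generator[None, None, None]:
    events.append("startup")
    yield
    events.append("shutdown")


class AsyncLift:
    """Expose a synchronous context manager through `async with`."""

    def __init__(self, cm: contextlib.AbstractContextManager[Any]) -> None:
        self._cm = cm

    async def __aenter__(self) -> Any:
        # The synchronous __enter__ runs inline on the event loop thread.
        return self._cm.__enter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, tb: Any) -> bool | None:
        return self._cm.__exit__(exc_type, exc_value, tb)


async def main() -> None:
    cm = contextlib.contextmanager(sync_lifespan)(None)
    async with AsyncLift(cm):
        events.append("running")


asyncio.run(main())
```

# Because the wrapped enter and exit steps run synchronously, any blocking
# work inside the generator blocks the event loop for its duration.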
class _DefaultLifespan:
def __init__(self, router: Router):
self._router = router
async def __aenter__(self) -> None:
pass
async def __aexit__(self, *exc_info: object) -> None:
pass
def __call__(self: _T, app: object) -> _T:
return self
class Router:
def __init__(
self,
routes: Sequence[BaseRoute] | None = None,
redirect_slashes: bool = True,
default: ASGIApp | None = None,
# the generic to Lifespan[AppType] is the type of the top level application
# which the router cannot know statically, so we use Any
lifespan: Lifespan[Any] | None = None,
*,
middleware: Sequence[Middleware] | None = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
self.default = self.not_found if default is None else default
if lifespan is None:
self.lifespan_context: Lifespan[Any] = _DefaultLifespan(self)
elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
"async generator function lifespans are deprecated, "
"use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = asynccontextmanager(lifespan)
elif inspect.isgeneratorfunction(lifespan):
warnings.warn(
"generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = _wrap_gen_lifespan_context(lifespan)
else:
self.lifespan_context = lifespan
self.middleware_stack = self.app
if middleware:
for cls, args, kwargs in reversed(middleware):
self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
websocket_close = WebSocketClose()
await websocket_close(scope, receive, send)
return
# If we're running inside a starlette application then raise an
# exception, so that the configurable exception handler can deal with
# returning the response. For plain ASGI apps, just return the response.
if "app" in scope:
raise HTTPException(status_code=404)
else:
response = PlainTextResponse("Not Found", status_code=404)
await response(scope, receive, send)
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
for route in self.routes:
try:
return route.url_path_for(name, **path_params)
except NoMatchFound:
pass
raise NoMatchFound(name, path_params)
async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
Handle ASGI lifespan messages, which allows us to manage application
startup and shutdown events.
"""
started = False
app: Any = scope.get("app")
await receive()
try:
async with self.lifespan_context(app) as maybe_state:
if maybe_state is not None:
if "state" not in scope:
raise RuntimeError('The server does not support "state" in the lifespan scope.')
scope["state"].update(maybe_state)
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
except BaseException:
exc_text = traceback.format_exc()
if started:
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
else:
await send({"type": "lifespan.startup.failed", "message": exc_text})
raise
else:
await send({"type": "lifespan.shutdown.complete"})
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
The main entry point to the Router class.
"""
await self.middleware_stack(scope, receive, send)
async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] in ("http", "websocket", "lifespan")
if "router" not in scope:
scope["router"] = self
if scope["type"] == "lifespan":
await self.lifespan(scope, receive, send)
return
partial = None
for route in self.routes:
# Determine if any route matches the incoming scope,
# and hand over to the matching route if found.
match, child_scope = route.matches(scope)
if match == Match.FULL:
scope.update(child_scope)
await route.handle(scope, receive, send)
return
elif match == Match.PARTIAL and partial is None:
partial = route
partial_scope = child_scope
if partial is not None:
# Handle partial matches. These are cases where an endpoint is
# able to handle the request, but is not a preferred option.
# We use this in particular to deal with "405 Method Not Allowed".
scope.update(partial_scope)
await partial.handle(scope, receive, send)
return
route_path = get_route_path(scope)
if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
redirect_scope = dict(scope)
if route_path.endswith("/"):
redirect_scope["path"] = redirect_scope["path"].rstrip("/")
else:
redirect_scope["path"] = redirect_scope["path"] + "/"
for route in self.routes:
match, child_scope = route.matches(redirect_scope)
if match != Match.NONE:
redirect_url = URL(scope=redirect_scope)
response = RedirectResponse(url=str(redirect_url))
await response(scope, receive, send)
return
await self.default(scope, receive, send)
def __eq__(self, other: Any) -> bool:
return isinstance(other, Router) and self.routes == other.routes
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
route = Mount(path, app=app, name=name)
self.routes.append(route)
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
route = Host(host, app=app, name=name)
self.routes.append(route)
def add_route(
self,
path: str,
endpoint: Callable[[Request], Awaitable[Response] | Response],
methods: Collection[str] | None = None,
name: str | None = None,
include_in_schema: bool = True,
) -> None: # pragma: no cover
route = Route(
path,
endpoint=endpoint,
methods=methods,
name=name,
include_in_schema=include_in_schema,
)
self.routes.append(route)
def add_websocket_route(
self,
path: str,
endpoint: Callable[[WebSocket], Awaitable[None]],
name: str | None = None,
) -> None: # pragma: no cover
route = WebSocketRoute(path, endpoint=endpoint, name=name)
self.routes.append(route)
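# Illustrative sketch (not part of the module): the resolution order used
# by `Router.app` above, reduced to status codes. A full match wins as soon
# as it is seen; the first partial match (path matches, method does not) is
# only used if no route matches fully. `SketchRoute` and `resolve` are
# hypothetical names, and path matching is simplified to string equality.

```python
from __future__ import annotations

from enum import Enum


class Match(Enum):
    NONE = 0
    PARTIAL = 1
    FULL = 2


class SketchRoute:
    def __init__(self, path: str, methods: set[str]) -> None:
        self.path, self.methods = path, methods

    def matches(self, path: str, method: str) -> Match:
        if path != self.path:
            return Match.NONE
        return Match.FULL if method in self.methods else Match.PARTIAL


def resolve(routes: list[SketchRoute], path: str, method: str) -> int:
    partial = None
    for route in routes:
        match = route.matches(path, method)
        if match is Match.FULL:
            return 200  # hand over to the fully matching route
        if match is Match.PARTIAL and partial is None:
            partial = route  # remember only the first partial match
    if partial is not None:
        return 405  # path exists, but not for this method
    return 404


routes = [
    SketchRoute("/items", {"GET", "HEAD"}),
    SketchRoute("/items", {"POST"}),
]
```

# With these routes, a POST to "/items" skips the partial match on the
# first route and is served by the second one.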
================================================
FILE: starlette/schemas.py
================================================
from __future__ import annotations
import inspect
import re
from collections.abc import Callable
from typing import Any, NamedTuple
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Host, Mount, Route
try:
import yaml
except ModuleNotFoundError: # pragma: no cover
yaml = None # type: ignore[assignment]
class OpenAPIResponse(Response):
media_type = "application/vnd.oai.openapi"
def render(self, content: Any) -> bytes:
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
return yaml.dump(content, default_flow_style=False).encode("utf-8")
class EndpointInfo(NamedTuple):
path: str
http_method: str
func: Callable[..., Any]
_remove_converter_pattern = re.compile(r":\w+}")
class BaseSchemaGenerator:
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
raise NotImplementedError() # pragma: no cover
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
"""
Given the routes, returns a list of `EndpointInfo` entries with:
- path
eg: /users/
- http_method
one of 'get', 'post', 'put', 'patch', 'delete', 'options'
- func
method ready to extract the docstring
"""
endpoints_info: list[EndpointInfo] = []
for route in routes:
if isinstance(route, Mount | Host):
routes = route.routes or []
if isinstance(route, Mount):
path = self._remove_converter(route.path)
else:
path = ""
sub_endpoints = [
EndpointInfo(
path="".join((path, sub_endpoint.path)),
http_method=sub_endpoint.http_method,
func=sub_endpoint.func,
)
for sub_endpoint in self.get_endpoints(routes)
]
endpoints_info.extend(sub_endpoints)
elif not isinstance(route, Route) or not route.include_in_schema:
continue
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
path = self._remove_converter(route.path)
for method in route.methods or ["GET"]:
if method == "HEAD":
continue
endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
else:
path = self._remove_converter(route.path)
for method in ["get", "post", "put", "patch", "delete", "options"]:
if not hasattr(route.endpoint, method):
continue
func = getattr(route.endpoint, method)
endpoints_info.append(EndpointInfo(path, method.lower(), func))
return endpoints_info
def _remove_converter(self, path: str) -> str:
"""
Remove the converter from the path.
For example, a route like this:
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
Should be represented as `/users/{id}` in the OpenAPI schema.
"""
return _remove_converter_pattern.sub("}", path)
def parse_docstring(self, func_or_method: Callable[..., Any]) -> dict[str, Any]:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
"""
docstring = func_or_method.__doc__
if not docstring:
return {}
assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."
# We support having regular docstrings before the schema
# definition. Here we return just the schema part from
# the docstring.
docstring = docstring.split("---")[-1]
parsed = yaml.safe_load(docstring)
if not isinstance(parsed, dict):
# A regular docstring (not yaml formatted) can return
# a simple string here, which wouldn't follow the schema.
return {}
return parsed
def OpenAPIResponse(self, request: Request) -> Response:
routes = request.app.routes
schema = self.get_schema(routes=routes)
return OpenAPIResponse(schema)
class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, base_schema: dict[str, Any]) -> None:
self.base_schema = base_schema
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
schema = dict(self.base_schema)
schema.setdefault("paths", {})
endpoints_info = self.get_endpoints(routes)
for endpoint in endpoints_info:
parsed = self.parse_docstring(endpoint.func)
if not parsed:
continue
if endpoint.path not in schema["paths"]:
schema["paths"][endpoint.path] = {}
schema["paths"][endpoint.path][endpoint.http_method] = parsed
return schema
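# Illustrative sketch (not part of the module): the two string operations
# the schema generator relies on, shown in isolation. `openapi_path` and
# `schema_part` are hypothetical helper names; the YAML parsing step is
# omitted so the example needs only the standard library.

```python
import re

remove_converter = re.compile(r":\w+}")


def openapi_path(path: str) -> str:
    """Drop convertor suffixes such as ':int' from a route path."""
    return remove_converter.sub("}", path)


def schema_part(docstring: str) -> str:
    """Keep only the text after the last '---' marker in a docstring."""
    return docstring.split("---")[-1]


doc = "Return one user.\n---\nresponses:\n  200:\n    description: A user.\n"
users_path = openapi_path("/users/{id:int}")
files_path = openapi_path("/files/{rest:path}/meta")
yaml_text = schema_part(doc)
```

# A docstring without a '---' marker is returned unchanged by
# `schema_part`, which is why `parse_docstring` still has to check that
# the parsed result is a dictionary.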
================================================
FILE: starlette/staticfiles.py
================================================
from __future__ import annotations
import errno
import importlib.util
import os
import stat
from email.utils import parsedate
from typing import Union
import anyio
import anyio.to_thread
from starlette._utils import get_route_path
from starlette.datastructures import URL, Headers
from starlette.exceptions import HTTPException
from starlette.responses import FileResponse, RedirectResponse, Response
from starlette.types import Receive, Scope, Send
PathLike = Union[str, "os.PathLike[str]"]
class NotModifiedResponse(Response):
NOT_MODIFIED_HEADERS = (
"cache-control",
"content-location",
"date",
"etag",
"expires",
"vary",
)
def __init__(self, headers: Headers):
super().__init__(
status_code=304,
headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
)
class StaticFiles:
def __init__(
self,
*,
directory: PathLike | None = None,
packages: list[str | tuple[str, str]] | None = None,
html: bool = False,
check_dir: bool = True,
follow_symlink: bool = False,
) -> None:
self.directory = directory
self.packages = packages
self.all_directories = self.get_directories(directory, packages)
self.html = html
self.config_checked = False
self.follow_symlink = follow_symlink
if check_dir and directory is not None and not os.path.isdir(directory):
raise RuntimeError(f"Directory '{directory}' does not exist")
def get_directories(
self,
directory: PathLike | None = None,
packages: list[str | tuple[str, str]] | None = None,
) -> list[PathLike]:
"""
Given `directory` and `packages` arguments, return a list of all the
directories that should be used for serving static files from.
"""
directories = []
if directory is not None:
directories.append(directory)
for package in packages or []:
if isinstance(package, tuple):
package, statics_dir = package
else:
statics_dir = "statics"
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
assert os.path.isdir(package_directory), (
f"Directory '{statics_dir!r}' in package {package!r} could not be found."
)
directories.append(package_directory)
return directories
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
The ASGI entry point.
"""
assert scope["type"] == "http"
if not self.config_checked:
await self.check_config()
self.config_checked = True
path = self.get_path(scope)
response = await self.get_response(path, scope)
await response(scope, receive, send)
def get_path(self, scope: Scope) -> str:
"""
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
route_path = get_route_path(scope)
return os.path.normpath(os.path.join(*route_path.split("/")))
async def get_response(self, path: str, scope: Scope) -> Response:
"""
Returns an HTTP response, given the incoming path, method and request headers.
"""
if scope["method"] not in ("GET", "HEAD"):
raise HTTPException(status_code=405)
try:
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
except PermissionError:
raise HTTPException(status_code=401)
except OSError as exc:
# Filename is too long, so it can't be a valid static file.
if exc.errno == errno.ENAMETOOLONG:
raise HTTPException(status_code=404)
raise exc
except ValueError:
# Null bytes or other invalid characters in the path.
raise HTTPException(status_code=404)
if stat_result and stat.S_ISREG(stat_result.st_mode):
# We have a static file to serve.
return self.file_response(full_path, stat_result, scope)
elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
index_path = os.path.join(path, "index.html")
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
if not scope["path"].endswith("/"):
# Directory URLs should redirect to always end in "/".
url = URL(scope=scope)
url = url.replace(path=url.path + "/")
return RedirectResponse(url=url)
return self.file_response(full_path, stat_result, scope)
if self.html:
# Check for '404.html' if we're in HTML mode.
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
if stat_result and stat.S_ISREG(stat_result.st_mode):
return FileResponse(full_path, stat_result=stat_result, status_code=404)
raise HTTPException(status_code=404)
    def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
        for directory in self.all_directories:
            joined_path = os.path.join(directory, path)
            if self.follow_symlink:
                full_path = os.path.abspath(joined_path)
                directory = os.path.abspath(directory)
            else:
                full_path = os.path.realpath(joined_path)
                directory = os.path.realpath(directory)
            if os.path.commonpath([full_path, directory]) != str(directory):
                # Don't allow misbehaving clients to break out of the static files directory.
                continue
            try:
                return full_path, os.stat(full_path)
            except (FileNotFoundError, NotADirectoryError):
                continue
        return "", None
    def file_response(
        self,
        full_path: PathLike,
        stat_result: os.stat_result,
        scope: Scope,
        status_code: int = 200,
    ) -> Response:
        request_headers = Headers(scope=scope)
        response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
        if self.is_not_modified(response.headers, request_headers):
            return NotModifiedResponse(response.headers)
        return response
    async def check_config(self) -> None:
        """
        Perform a one-off configuration check that StaticFiles is actually
        pointed at a directory, so that we can raise loud errors rather than
        just returning 404 responses.
        """
        if self.directory is None:
            return
        try:
            stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
        except FileNotFoundError:
            raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
        if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
            raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
    def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
        """
        Given the request and response headers, return `True` if an HTTP
        "Not Modified" response could be returned instead.
        """
        if if_none_match := request_headers.get("if-none-match"):
            # The "etag" header is added by FileResponse, so it's always present.
            etag = response_headers["etag"]
            return etag in [tag.strip(" W/") for tag in if_none_match.split(",")]
        try:
            if_modified_since = parsedate(request_headers["if-modified-since"])
            last_modified = parsedate(response_headers["last-modified"])
            if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
                return True
        except KeyError:
            pass
        return False
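The two checks that `StaticFiles` relies on above, directory containment in `lookup_path` and conditional-request matching in `is_not_modified`, can be tried in isolation with only the standard library. The sketch below is illustrative and not part of Starlette: the helper names `is_inside` and `not_modified`, the sample paths, and the header values are assumptions, and `os.path.abspath` is used so that no real files are needed.

```python
import os
from email.utils import parsedate


def is_inside(directory: str, request_path: str) -> bool:
    """Return True if request_path, joined onto directory, stays inside it."""
    directory = os.path.abspath(directory)
    full_path = os.path.abspath(os.path.join(directory, request_path))
    # commonpath() equals the directory only when full_path lives under it.
    return os.path.commonpath([full_path, directory]) == directory


def not_modified(etag: str, last_modified: str, request_headers: dict[str, str]) -> bool:
    """Mirror the If-None-Match / If-Modified-Since logic using plain dicts."""
    if if_none_match := request_headers.get("if-none-match"):
        # Strip spaces and a weak-validator "W/" prefix from each listed tag.
        return etag in [tag.strip(" W/") for tag in if_none_match.split(",")]
    try:
        since = parsedate(request_headers["if-modified-since"])
        modified = parsedate(last_modified)
        if since is not None and modified is not None and since >= modified:
            return True
    except KeyError:
        pass
    return False
```

A path such as `../secret.txt` resolves outside the hypothetical `/srv/static` root and is rejected, while an `If-None-Match` header listing the current ETag (with or without a `W/` prefix) is treated as a match.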
================================================
FILE: starlette/status.py
================================================
"""
HTTP codes
See HTTP Status Code Registry:
https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
And RFC 9110 - https://www.rfc-editor.org/rfc/rfc9110
"""
from __future__ import annotations
import warnings
__all__ = [
"HTTP_100_CONTINUE",
"HTTP_101_SWITCHING_PROTOCOLS",
"HTTP_102_PROCESSING",
"HTTP_103_EARLY_HINTS",
"HTTP_200_OK",
"HTTP_201_CREATED",
"HTTP_202_ACCEPTED",
"HTTP_203_NON_AUTHORITATIVE_INFORMATION",
"HTTP_204_NO_CONTENT",
"HTTP_205_RESET_CONTENT",
"HTTP_206_PARTIAL_CONTENT",
"HTTP_207_MULTI_STATUS",
"HTTP_208_ALREADY_REPORTED",
"HTTP_226_IM_USED",
"HTTP_300_MULTIPLE_CHOICES",
"HTTP_301_MOVED_PERMANENTLY",
"HTTP_302_FOUND",
"HTTP_303_SEE_OTHER",
"HTTP_304_NOT_MODIFIED",
"HTTP_305_USE_PROXY",
"HTTP_306_RESERVED",
"HTTP_307_TEMPORARY_REDIRECT",
"HTTP_308_PERMANENT_REDIRECT",
"HTTP_400_BAD_REQUEST",
"HTTP_401_UNAUTHORIZED",
"HTTP_402_PAYMENT_REQUIRED",
"HTTP_403_FORBIDDEN",
"HTTP_404_NOT_FOUND",
"HTTP_405_METHOD_NOT_ALLOWED",
"HTTP_406_NOT_ACCEPTABLE",
"HTTP_407_PROXY_AUTHENTICATION_REQUIRED",
"HTTP_408_REQUEST_TIMEOUT",
"HTTP_409_CONFLICT",
"HTTP_410_GONE",
"HTTP_411_LENGTH_REQUIRED",
"HTTP_412_PRECONDITION_FAILED",
"HTTP_413_CONTENT_TOO_LARGE",
"HTTP_414_URI_TOO_LONG",
"HTTP_415_UNSUPPORTED_MEDIA_TYPE",
"HTTP_416_RANGE_NOT_SATISFIABLE",
"HTTP_417_EXPECTATION_FAILED",
"HTTP_418_IM_A_TEAPOT",
"HTTP_421_MISDIRECTED_REQUEST",
"HTTP_422_UNPROCESSABLE_CONTENT",
"HTTP_423_LOCKED",
"HTTP_424_FAILED_DEPENDENCY",
"HTTP_425_TOO_EARLY",
"HTTP_426_UPGRADE_REQUIRED",
"HTTP_428_PRECONDITION_REQUIRED",
"HTTP_429_TOO_MANY_REQUESTS",
"HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE",
"HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS",
"HTTP_500_INTERNAL_SERVER_ERROR",
"HTTP_501_NOT_IMPLEMENTED",
"HTTP_502_BAD_GATEWAY",
"HTTP_503_SERVICE_UNAVAILABLE",
"HTTP_504_GATEWAY_TIMEOUT",
"HTTP_505_HTTP_VERSION_NOT_SUPPORTED",
"HTTP_506_VARIANT_ALSO_NEGOTIATES",
"HTTP_507_INSUFFICIENT_STORAGE",
"HTTP_508_LOOP_DETECTED",
"HTTP_510_NOT_EXTENDED",
"HTTP_511_NETWORK_AUTHENTICATION_REQUIRED",
"WS_1000_NORMAL_CLOSURE",
"WS_1001_GOING_AWAY",
"WS_1002_PROTOCOL_ERROR",
"WS_1003_UNSUPPORTED_DATA",
"WS_1005_NO_STATUS_RCVD",
"WS_1006_ABNORMAL_CLOSURE",
"WS_1007_INVALID_FRAME_PAYLOAD_DATA",
"WS_1008_POLICY_VIOLATION",
"WS_1009_MESSAGE_TOO_BIG",
"WS_1010_MANDATORY_EXT",
"WS_1011_INTERNAL_ERROR",
"WS_1012_SERVICE_RESTART",
"WS_1013_TRY_AGAIN_LATER",
"WS_1014_BAD_GATEWAY",
"WS_1015_TLS_HANDSHAKE",
]
HTTP_100_CONTINUE = 100
HTTP_101_SWITCHING_PROTOCOLS = 101
HTTP_102_PROCESSING = 102
HTTP_103_EARLY_HINTS = 103
HTTP_200_OK = 200
HTTP_201_CREATED = 201
HTTP_202_ACCEPTED = 202
HTTP_203_NON_AUTHORITATIVE_INFORMATION = 203
HTTP_204_NO_CONTENT = 204
HTTP_205_RESET_CONTENT = 205
HTTP_206_PARTIAL_CONTENT = 206
HTTP_207_MULTI_STATUS = 207
HTTP_208_ALREADY_REPORTED = 208
HTTP_226_IM_USED = 226
HTTP_300_MULTIPLE_CHOICES = 300
HTTP_301_MOVED_PERMANENTLY = 301
HTTP_302_FOUND = 302
HTTP_303_SEE_OTHER = 303
HTTP_304_NOT_MODIFIED = 304
HTTP_305_USE_PROXY = 305
HTTP_306_RESERVED = 306
HTTP_307_TEMPORARY_REDIRECT = 307
HTTP_308_PERMANENT_REDIRECT = 308
HTTP_400_BAD_REQUEST = 400
HTTP_401_UNAUTHORIZED = 401
HTTP_402_PAYMENT_REQUIRED = 402
HTTP_403_FORBIDDEN = 403
HTTP_404_NOT_FOUND = 404
HTTP_405_METHOD_NOT_ALLOWED = 405
HTTP_406_NOT_ACCEPTABLE = 406
HTTP_407_PROXY_AUTHENTICATION_REQUIRED = 407
HTTP_408_REQUEST_TIMEOUT = 408
HTTP_409_CONFLICT = 409
HTTP_410_GONE = 410
HTTP_411_LENGTH_REQUIRED = 411
HTTP_412_PRECONDITION_FAILED = 412
HTTP_413_CONTENT_TOO_LARGE = 413
HTTP_414_URI_TOO_LONG = 414
HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415
HTTP_416_RANGE_NOT_SATISFIABLE = 416
HTTP_417_EXPECTATION_FAILED = 417
HTTP_418_IM_A_TEAPOT = 418
HTTP_421_MISDIRECTED_REQUEST = 421
HTTP_422_UNPROCESSABLE_CONTENT = 422
HTTP_423_LOCKED = 423
HTTP_424_FAILED_DEPENDENCY = 424
HTTP_425_TOO_EARLY = 425
HTTP_426_UPGRADE_REQUIRED = 426
HTTP_428_PRECONDITION_REQUIRED = 428
HTTP_429_TOO_MANY_REQUESTS = 429
HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 431
HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS = 451
HTTP_500_INTERNAL_SERVER_ERROR = 500
HTTP_501_NOT_IMPLEMENTED = 501
HTTP_502_BAD_GATEWAY = 502
HTTP_503_SERVICE_UNAVAILABLE = 503
HTTP_504_GATEWAY_TIMEOUT = 504
HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505
HTTP_506_VARIANT_ALSO_NEGOTIATES = 506
HTTP_507_INSUFFICIENT_STORAGE = 507
HTTP_508_LOOP_DETECTED = 508
HTTP_510_NOT_EXTENDED = 510
HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511
"""
WebSocket codes
https://www.iana.org/assignments/websocket/websocket.xml#close-code-number
https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent
"""
WS_1000_NORMAL_CLOSURE = 1000
WS_1001_GOING_AWAY = 1001
WS_1002_PROTOCOL_ERROR = 1002
WS_1003_UNSUPPORTED_DATA = 1003
WS_1005_NO_STATUS_RCVD = 1005
WS_1006_ABNORMAL_CLOSURE = 1006
WS_1007_INVALID_FRAME_PAYLOAD_DATA = 1007
WS_1008_POLICY_VIOLATION = 1008
WS_1009_MESSAGE_TOO_BIG = 1009
WS_1010_MANDATORY_EXT = 1010
WS_1011_INTERNAL_ERROR = 1011
WS_1012_SERVICE_RESTART = 1012
WS_1013_TRY_AGAIN_LATER = 1013
WS_1014_BAD_GATEWAY = 1014
WS_1015_TLS_HANDSHAKE = 1015
__deprecated__ = {
    "HTTP_413_REQUEST_ENTITY_TOO_LARGE": 413,
    "HTTP_414_REQUEST_URI_TOO_LONG": 414,
    "HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE": 416,
    "HTTP_422_UNPROCESSABLE_ENTITY": 422,
}
def __getattr__(name: str) -> int:
    deprecation_changes = {
        "HTTP_413_REQUEST_ENTITY_TOO_LARGE": "HTTP_413_CONTENT_TOO_LARGE",
        "HTTP_414_REQUEST_URI_TOO_LONG": "HTTP_414_URI_TOO_LONG",
        "HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE": "HTTP_416_RANGE_NOT_SATISFIABLE",
        "HTTP_422_UNPROCESSABLE_ENTITY": "HTTP_422_UNPROCESSABLE_CONTENT",
    }
    deprecated = __deprecated__.get(name)
    if deprecated:
        warnings.warn(
            f"'{name}' is deprecated. Use '{deprecation_changes[name]}' instead.",
            category=DeprecationWarning,
            stacklevel=3,
        )
        return deprecated
    raise AttributeError(f"module 'starlette.status' has no attribute '{name}'")
def __dir__() -> list[str]:
    return sorted(list(__all__) + list(__deprecated__.keys())) # pragma: no cover
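The deprecation shim above relies on module-level `__getattr__` (PEP 562), which Python calls only when normal attribute lookup on the module fails. The following is a minimal standard-library sketch of that pattern, not Starlette code: the module name `example_status` and the helper `make_status_module` are hypothetical, and `stacklevel=2` is chosen for this demo only.

```python
import types
import warnings


def make_status_module() -> types.ModuleType:
    """Build a throwaway module that warns when a renamed constant is accessed."""
    module = types.ModuleType("example_status")
    module.HTTP_422_UNPROCESSABLE_CONTENT = 422
    renamed = {"HTTP_422_UNPROCESSABLE_ENTITY": "HTTP_422_UNPROCESSABLE_CONTENT"}

    def module_getattr(name: str) -> int:
        # Only reached when `name` is not a regular attribute of the module.
        if name in renamed:
            new_name = renamed[name]
            warnings.warn(
                f"'{name}' is deprecated. Use '{new_name}' instead.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return getattr(module, new_name)
        raise AttributeError(f"module {module.__name__!r} has no attribute {name!r}")

    # PEP 562: a function stored as `__getattr__` in the module namespace
    # is used as the fallback for missing attributes.
    module.__getattr__ = module_getattr
    return module
```

Reading the old name from the returned module yields 422 and issues a `DeprecationWarning`, while the new name resolves through ordinary lookup without touching the hook.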
================================================
FILE: starlette/templating.py
================================================
from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence
from os import PathLike
from typing import TYPE_CHECKING, Any, overload
from starlette.background import BackgroundTask
from starlette.datastructures import URL
from starlette.requests import Request
from starlette.responses import HTMLResponse
from starlette.types import Receive, Scope, Send
try:
    import jinja2
    # @contextfunction was renamed to @pass_context in Jinja 3.0, and was removed in 3.1
    # hence we try to get pass_context (most installs will be >=3.1)
    # and fall back to contextfunction,
    # adding a type ignore for mypy to let us access an attribute that may not exist
    if TYPE_CHECKING:
        pass_context = jinja2.pass_context
    else:
        if hasattr(jinja2, "pass_context"):
            pass_context = jinja2.pass_context
        else: # pragma: no cover
            pass_context = jinja2.contextfunction # type: ignore[attr-defined]
except ImportError as _import_error: # pragma: no cover
    raise ImportError("jinja2 must be installed to use Jinja2Templates") from _import_error
class _TemplateResponse(HTMLResponse):
    def __init__(
        self,
        template: Any,
        context: dict[str, Any],
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ):
        self.template = template
        self.context = context
        content = template.render(context)
        super().__init__(content, status_code, headers, media_type, background)
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        request = self.context.get("request", {})
        extensions = request.get("extensions", {})
        if "http.response.debug" in extensions: # pragma: no branch
            await send({"type": "http.response.debug", "info": {"template": self.template, "context": self.context}})
        await super().__call__(scope, receive, send)
class Jinja2Templates:
    """Jinja2 template renderer.
    Example:
        ```python
        from starlette.templating import Jinja2Templates
        templates = Jinja2Templates(directory="templates")
        async def homepage(request: Request) -> Response:
            return templates.TemplateResponse(request, "index.html")
        ```
    """
    @overload
    def __init__(
        self,
        directory: str | PathLike[str] | Sequence[str | PathLike[str]],
        *,
        context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
    ) -> None: ...
    @overload
    def __init__(
        self,
        *,
        env: jinja2.Environment,
        context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
    ) -> None: ...
    def __init__(
        self,
        directory: str | PathLike[str] | Sequence[str | PathLike[str]] | None = None,
        *,
        context_processors: list[Callable[[Request], dict[str, Any]]] | None = None,
        env: jinja2.Environment | None = None,
    ) -> None:
        assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
        self.context_processors = context_processors or []
        if directory is not None:
            loader = jinja2.FileSystemLoader(directory)
            self.env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape())
        elif env is not None: # pragma: no branch
            self.env = env
        self._setup_env_defaults(self.env)
    def _setup_env_defaults(self, env: jinja2.Environment) -> None:
        @pass_context
        def url_for(
            context: dict[str, Any],
            name: str,
            /,
            **path_params: Any,
        ) -> URL:
            request: Request = context["request"]
            return request.url_for(name, **path_params)
        env.globals.setdefault("url_for", url_for)
    def get_template(self, name: str) -> jinja2.Template:
        return self.env.get_template(name)
    def TemplateResponse(
        self,
        request: Request,
        name: str,
        context: dict[str, Any] | None = None,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> _TemplateResponse:
        """
        Render a template and return an HTML response.
        Args:
            request: The incoming request instance.
            name: The template file name to render.
            context: Variables to pass to the template.
            status_code: HTTP status code for the response.
            headers: Additional headers to include in the response.
            media_type: Media type for the response.
            background: Background task to run after response is sent.
        Returns:
            An HTML response with the rendered template content.
        """
        context = context or {}
        context.setdefault("request", request)
        for context_processor in self.context_processors:
            context.update(context_processor(request))
        template = self.get_template(name)
        return _TemplateResponse(
            template,
            context,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )
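The context handling in `TemplateResponse` above has an ordering that is easy to miss: a non-empty dict supplied by the caller is reused and mutated, `request` is only filled in if absent, and context processors run last, so their keys win on conflict. The sketch below reproduces that merge in plain Python and is not Starlette code: `build_context` and `add_site_name` are hypothetical names, and any object can stand in for the request.

```python
from collections.abc import Callable
from typing import Any, Optional


def build_context(
    request: Any,
    context: Optional[dict[str, Any]],
    context_processors: list[Callable[[Any], dict[str, Any]]],
) -> dict[str, Any]:
    """Merge caller context, the request, and processor output in that order."""
    context = context or {}
    # Keep a caller-supplied "request" entry if one is already present.
    context.setdefault("request", request)
    for processor in context_processors:
        # Later updates overwrite earlier keys, including the caller's.
        context.update(processor(request))
    return context


def add_site_name(request: Any) -> dict[str, Any]:
    """Example processor (hypothetical) that injects a fixed value."""
    return {"site_name": "Example"}
```

With a caller context of `{"site_name": "Mine"}` and `add_site_name` as the only processor, the processor's value replaces the caller's, which is worth keeping in mind when choosing key names for processors.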
================================================
FILE: starlette/testclient.py
================================================
from __future__ import annotations
import contextlib
import inspect
import io
import json
import math
import sys
import warnings
from collections.abc import Awaitable, Callable, Generator, Iterable, Mapping, MutableMapping, Sequence
from concurrent.futures import Future
from contextlib import AbstractContextManager
from types import GeneratorType
from typing import (
Any,
Literal,
TypedDict,
TypeGuard,
cast,
)
from urllib.parse import unquote, urljoin
import anyio
import anyio.abc
import anyio.from_thread
from anyio.streams.stapled import StapledObjectStream
from starlette._utils import is_async_callable
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect
if sys.version_info >= (3, 11): # pragma: no cover
    from typing import Self
else: # pragma: no cover
    from typing_extensions import Self
try:
    import httpx
except ModuleNotFoundError: # pragma: no cover
    raise RuntimeError(
        "The starlette.testclient module requires the httpx package to be installed.\n"
        "You can install this with:\n"
        " $ pip install httpx\n"
    )
_PortalFactoryType = Callable[[], AbstractContextManager[anyio.abc.BlockingPortal]]
ASGIInstance = Callable[[Receive, Send], Awaitable[None]]
ASGI2App = Callable[[Scope], ASGIInstance]
ASGI3App = Callable[[Scope, Receive, Send], Awaitable[None]]
_RequestData = Mapping[str, str | Iterable[str] | bytes]
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
    if inspect.isclass(app):
        return hasattr(app, "__await__")
    return is_async_callable(app)
class _WrapASGI2:
    """
    Provide an ASGI3 interface onto an ASGI2 app.
    """
    def __init__(self, app: ASGI2App) -> None:
        self.app = app
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        instance = self.app(scope)
        await instance(receive, send)
class _AsyncBackend(TypedDict):
    backend: str
    backend_options: dict[str, Any]
class _Upgrade(Exception):
    def __init__(self, session: WebSocketTestSession) -> None:
        self.session = session
class WebSocketDenialResponse( # type: ignore[misc]
    httpx.Response,
    WebSocketDisconnect,
):
    """
    A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
    `WebSocket` is closed before being accepted with a `send_denial_response()`.
    """
class WebSocketTestSession:
    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.accepted_subprotocol = None
        self.portal_factory = portal_factory
        self.extra_headers = None
    def __enter__(self) -> WebSocketTestSession:
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(self.portal_factory())
            fut, cs = portal.start_task(self._run)
            stack.callback(fut.result)
            stack.callback(portal.call, cs.cancel)
            self.send({"type": "websocket.connect"})
            message = self.receive()
            self._raise_on_close(message)
            self.accepted_subprotocol = message.get("subprotocol", None)
            self.extra_headers = message.get("headers", None)
            stack.callback(self.close, 1000)
            self.exit_stack = stack.pop_all()
            return self
    def __exit__(self, *args: Any) -> bool | None:
        return self.exit_stack.__exit__(*args)
    async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
        """
        The sub-thread in which the websocket session runs.
        """
        send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        send_tx, send_rx = send
        receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf)
        receive_tx, receive_rx = receive
        with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs:
            self._receive_tx = receive_tx
            self._send_rx = send_rx
            task_status.started(cs)
            await self.app(self.scope, receive_rx.receive, send_tx.send)
            # wait for cs.cancel to be called before closing streams
            await anyio.sleep_forever()
    def _raise_on_close(self, message: Message) -> None:
        if message["type"] == "websocket.close":
            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
        elif message["type"] == "websocket.http.response.start":
            status_code: int = message["status"]
            headers: list[tuple[bytes, bytes]] = message["headers"]
            body: list[bytes] = []
            while True:
                message = self.receive()
                assert message["type"] == "websocket.http.response.body"
                body.append(message["body"])
                if not message.get("more_body", False):
                    break
            raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body))
    def send(self, message: Message) -> None:
        self.portal.call(self._receive_tx.send, message)
    def send_text(self, data: str) -> None:
        self.send({"type": "websocket.receive", "text": data})
    def send_bytes(self, data: bytes) -> None:
        self.send({"type": "websocket.receive", "bytes": data})
    def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None:
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            self.send({"type": "websocket.receive", "text": text})
        else:
            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
    def close(self, code: int = 1000, reason: str | None = None) -> None:
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})
    def receive(self) -> Message:
        return self.portal.call(self._send_rx.receive)
    def receive_text(self) -> str:
        message = self.receive()
        self._raise_on_close(message)
        return cast(str, message["text"])
    def receive_bytes(self) -> bytes:
        message = self.receive()
        self._raise_on_close(message)
        return cast(bytes, message["bytes"])
    def receive_json(self, mode: Literal["text", "binary"] = "text") -> Any:
        message = self.receive()
        self._raise_on_close(message)
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)
class _TestClientTransport(httpx.BaseTransport):
    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        client: tuple[str, int],
        app_state: dict[str, Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        self.app_state = app_state
        self.client = client
    def handle_request(self, request: httpx.Request) -> httpx.Response:
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")
        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
        if ":" in netloc:
            host, port_string = netloc.split(":", 1)
            port = int(port_string)
        else:
            host = netloc
            port = default_port
        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        elif port == default_port: # pragma: no cover
            headers = [(b"host", host.encode())]
        else: # pragma: no cover
            headers = [(b"host", (f"{host}:{port}").encode())]
        # Include other request headers.
        headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
        scope: dict[str, Any]
        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path.split(b"?", 1)[0],
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": self.client,
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)
        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path.split(b"?", 1)[0],
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": self.client,
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }
        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: dict[str, Any] = {"stream": io.BytesIO()}
        template = None
        context = None
        async def receive() -> Message:
            nonlocal request_complete
            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}
            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8") # pragma: no cover
            elif body is None:
                body_bytes = b"" # pragma: no cover
            elif isinstance(body, GeneratorType):
                try: # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration: # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body
            request_complete = True
            return {"type": "http.request", "body": body_bytes}
        async def send(message: Message) -> None:
            nonlocal raw_kwargs, response_started, template, context
            if message["type"] == "http.response.start":
                assert not response_started, 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
                response_started = True
            elif message["type"] == "http.response.body":
                assert response_started, 'Received "http.response.body" without "http.response.start".'
                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]
        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException as exc:
            if self.raise_server_exceptions:
                raise exc
        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }
        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())
        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template # type: ignore[attr-defined]
            response.context = context # type: ignore[attr-defined]
        return response
class TestClient(httpx.Client):
    __test__ = False
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None
    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
        client: tuple[str, int] = ("testclient", 50000),
    ) -> None:
        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = cast(ASGI2App, app) # type: ignore[assignment]
            asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
        self.app = asgi_app
        self.app_state: dict[str, Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
            client=client,
        )
        if headers is None:
            headers = {}
        headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )
    @contextlib.contextmanager
    def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]:
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                yield portal
    def request( # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        if timeout is not httpx.USE_CLIENT_DEFAULT:
            warnings.warn(
                "You should not use the 'timeout' argument with the TestClient. "
                "See https://github.com/Kludex/starlette/issues/1108 for more information.",
                DeprecationWarning,
            )
        url = self._merge_url(url)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )
    def get( # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )
    def options( # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )
    def head( # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )
    def post( # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )
    def put( # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )
    def patch( # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )
    def delete( # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, Any] | None = None,
    ) -> httpx.Response:
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=extensions,
        )
    def websocket_connect(
        self,
        url: str,
        subprotocols: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> WebSocketTestSession:
        url = urljoin("ws://testserver", url)
        headers = kwargs.get("headers", {})
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover
        return session
def __enter__(self) -> Self:
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
@stack.callback
def reset_portal() -> None:
self.portal = None
send: anyio.create_memory_object_stream[MutableMapping[str, Any] | None] = (
anyio.create_memory_object_stream(math.inf)
)
receive: anyio.create_memory_object_stream[MutableMapping[str, Any]] = anyio.create_memory_object_stream(
math.inf
)
for channel in (*send, *receive):
stack.callback(channel.close)
self.stream_send = StapledObjectStream(*send)
self.stream_receive = StapledObjectStream(*receive)
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)
@stack.callback
def wait_shutdown() -> None:
portal.call(self.wait_shutdown)
self.exit_stack = stack.pop_all()
return self
def __exit__(self, *args: Any) -> None:
self.exit_stack.close()
async def lifespan(self) -> None:
scope = {"type": "lifespan", "state": self.app_state}
try:
await self.app(scope, self.stream_receive.receive, self.stream_send.send)
finally:
await self.stream_send.send(None)
async def wait_startup(self) -> None:
await self.stream_receive.send({"type": "lifespan.startup"})
async def receive() -> Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message
message = await receive()
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
)
if message["type"] == "lifespan.startup.failed":
await receive()
async def wait_shutdown(self) -> None:
async def receive() -> Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()
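
The lifespan handshake that `wait_startup` and `wait_shutdown` perform can be sketched with the standard library alone. This is a minimal illustration under stated assumptions, not the `TestClient` implementation: `lifespan_app` and `drive_lifespan` are hypothetical names, and `asyncio.Queue` stands in for the anyio memory object streams used above.

```python
import asyncio


async def lifespan_app(scope, receive, send):
    # Hypothetical ASGI app that only handles the "lifespan" scope.
    assert scope["type"] == "lifespan"
    while True:
        message = await receive()
        if message["type"] == "lifespan.startup":
            await send({"type": "lifespan.startup.complete"})
        elif message["type"] == "lifespan.shutdown":
            await send({"type": "lifespan.shutdown.complete"})
            return


async def drive_lifespan(app):
    # Two in-memory queues play the role of the receive/send channels.
    to_app = asyncio.Queue()
    from_app = asyncio.Queue()
    task = asyncio.create_task(
        app({"type": "lifespan", "state": {}}, to_app.get, from_app.put)
    )
    replies = []
    for event in ("lifespan.startup", "lifespan.shutdown"):
        # Send the event, then wait for the app's acknowledgement.
        await to_app.put({"type": event})
        replies.append((await from_app.get())["type"])
    await task
    return replies


replies = asyncio.run(drive_lifespan(lifespan_app))
print(replies)
```

The driver waits for a reply after each event, which mirrors how the client blocks on startup before serving requests and on shutdown before closing its exit stack.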
================================================
FILE: starlette/types.py
================================================
from collections.abc import Awaitable, Callable, Mapping, MutableMapping
from contextlib import AbstractAsyncContextManager
from typing import TYPE_CHECKING, Any, TypeVar
if TYPE_CHECKING:
from starlette.requests import Request
from starlette.responses import Response
from starlette.websockets import WebSocket
AppType = TypeVar("AppType")
Scope = MutableMapping[str, Any]
Message = MutableMapping[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
StatelessLifespan = Callable[[AppType], AbstractAsyncContextManager[None]]
StatefulLifespan = Callable[[AppType], AbstractAsyncContextManager[Mapping[str, Any]]]
Lifespan = StatelessLifespan[AppType] | StatefulLifespan[AppType]
HTTPExceptionHandler = Callable[["Request", Exception], "Response | Awaitable[Response]"]
WebSocketExceptionHandler = Callable[["WebSocket", Exception], Awaitable[None]]
ExceptionHandler = HTTPExceptionHandler | WebSocketExceptionHandler
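
A minimal sketch of how these aliases fit together, assuming only the standard library. The aliases are re-declared locally so the snippet runs without Starlette installed, and `hello_app` and `run_once` are hypothetical names used for illustration.

```python
import asyncio
from collections.abc import Awaitable, Callable, MutableMapping
from typing import Any

# Local copies of the aliases above, so the sketch runs without Starlette.
Scope = MutableMapping[str, Any]
Message = MutableMapping[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]


async def hello_app(scope: Scope, receive: Receive, send: Send) -> None:
    # Hypothetical app: answers every HTTP request with a fixed plain-text body.
    assert scope["type"] == "http"
    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [(b"content-type", b"text/plain")],
        }
    )
    await send({"type": "http.response.body", "body": b"hello", "more_body": False})


async def run_once(app: ASGIApp) -> list[Message]:
    # Drive the app once with an in-memory receive/send pair
    # and collect every message it sends.
    sent: list[Message] = []

    async def receive() -> Message:
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message: Message) -> None:
        sent.append(message)

    await app({"type": "http", "method": "GET", "path": "/"}, receive, send)
    return sent


messages = asyncio.run(run_once(hello_app))
print([m["type"] for m in messages])
```

Any callable with the `ASGIApp` shape can be passed to `run_once`, which is the same contract middleware relies on when wrapping another app.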
================================================
FILE: starlette/websockets.py
================================================
from __future__ import annotations
import enum
import json
from collections.abc import AsyncIterator, Iterable
from typing import Any, cast
from starlette.requests import HTTPConnection, StateT
from starlette.responses import Response
from starlette.types import Message, Receive, Scope, Send
class WebSocketState(enum.Enum):
CONNECTING = 0
CONNECTED = 1
DISCONNECTED = 2
RESPONSE = 3
class WebSocketDisconnect(Exception):
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
self.code = code
self.reason = reason or ""
class WebSocket(HTTPConnection[StateT]):
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
super().__init__(scope)
assert scope["type"] == "websocket"
self._receive = receive
self._send = send
self.client_state = WebSocketState.CONNECTING
self.application_state = WebSocketState.CONNECTING
async def receive(self) -> Message:
"""
Receive ASGI websocket messages, ensuring valid state transitions.
"""
if self.client_state == WebSocketState.CONNECTING:
message = await self._receive()
message_type = message["type"]
if message_type != "websocket.connect":
raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
self.client_state = WebSocketState.CONNECTED
return message
elif self.client_state == WebSocketState.CONNECTED:
message = await self._receive()
message_type = message["type"]
if message_type not in {"websocket.receive", "websocket.disconnect"}:
raise RuntimeError(
f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
)
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
return message
else:
raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
async def send(self, message: Message) -> None:
"""
Send ASGI websocket messages, ensuring valid state transitions.
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
raise RuntimeError(
'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
f"but got {message_type!r}"
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
elif message_type == "websocket.http.response.start":
self.application_state = WebSocketState.RESPONSE
else:
self.application_state = WebSocketState.CONNECTED
await self._send(message)
elif self.application_state == WebSocketState.CONNECTED:
message_type = message["type"]
if message_type not in {"websocket.send", "websocket.close"}:
raise RuntimeError(
f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
try:
await self._send(message)
except OSError:
self.application_state = WebSocketState.DISCONNECTED
raise WebSocketDisconnect(code=1006)
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
if message_type != "websocket.http.response.body":
raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')
async def accept(
self,
subprotocol: str | None = None,
headers: Iterable[tuple[bytes, bytes]] | None = None,
) -> None:
headers = headers or []
if self.client_state == WebSocketState.CONNECTING: # pragma: no branch
# If we haven't yet seen the 'connect' message, then wait for it first.
await self.receive()
await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
def _raise_on_disconnect(self, message: Message) -> None:
if message["type"] == "websocket.disconnect":
raise WebSocketDisconnect(message["code"], message.get("reason"))
async def receive_text(self) -> str:
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return cast(str, message["text"])
async def receive_bytes(self) -> bytes:
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return cast(bytes, message["bytes"])
async def receive_json(self, mode: str = "text") -> Any:
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
if mode == "text":
text = message["text"]
else:
text = message["bytes"].decode("utf-8")
return json.loads(text)
async def iter_text(self) -> AsyncIterator[str]:
try:
while True:
yield await self.receive_text()
except WebSocketDisconnect:
pass
async def iter_bytes(self) -> AsyncIterator[bytes]:
try:
while True:
yield await self.receive_bytes()
except WebSocketDisconnect:
pass
async def iter_json(self) -> AsyncIterator[Any]:
try:
while True:
yield await self.receive_json()
except WebSocketDisconnect:
pass
async def send_text(self, data: str) -> None:
await self.send({"type": "websocket.send", "text": data})
async def send_bytes(self, data: bytes) -> None:
await self.send({"type": "websocket.send", "bytes": data})
async def send_json(self, data: Any, mode: str = "text") -> None:
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
if mode == "text":
await self.send({"type": "websocket.send", "text": text})
else:
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
async def close(self, code: int = 1000, reason: str | None = None) -> None:
await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
async def send_denial_response(self, response: Response) -> None:
if "websocket.http.response" in self.scope.get("extensions", {}):
await response(self.scope, self.receive, self.send)
else:
raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
class WebSocketClose:
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
self.code = code
self.reason = reason or ""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "websocket.close", "code": self.code, "reason": self.reason})
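
The state transitions that `WebSocket.send` enforces can be summarized as a small pure function. This is a simplified stand-in written for illustration, not Starlette's class: `next_application_state` is a hypothetical helper, the enum is re-declared locally so the snippet runs on its own, and the `OSError` handling in the real method is left out.

```python
import enum


class WebSocketState(enum.Enum):
    CONNECTING = 0
    CONNECTED = 1
    DISCONNECTED = 2
    RESPONSE = 3


def next_application_state(state, message):
    """Hypothetical helper: return the state after sending `message`, or raise."""
    message_type = message["type"]
    if state is WebSocketState.CONNECTING:
        transitions = {
            "websocket.accept": WebSocketState.CONNECTED,
            "websocket.close": WebSocketState.DISCONNECTED,
            "websocket.http.response.start": WebSocketState.RESPONSE,
        }
        if message_type not in transitions:
            raise RuntimeError(f"Unexpected message {message_type!r} while connecting")
        return transitions[message_type]
    if state is WebSocketState.CONNECTED:
        if message_type not in {"websocket.send", "websocket.close"}:
            raise RuntimeError(f"Unexpected message {message_type!r} while connected")
        if message_type == "websocket.close":
            return WebSocketState.DISCONNECTED
        return WebSocketState.CONNECTED
    if state is WebSocketState.RESPONSE:
        if message_type != "websocket.http.response.body":
            raise RuntimeError(f"Unexpected message {message_type!r} during denial response")
        # The denial response ends once a body message arrives without more_body.
        if message.get("more_body", False):
            return WebSocketState.RESPONSE
        return WebSocketState.DISCONNECTED
    raise RuntimeError("Cannot send once a close message has been sent.")


state = WebSocketState.CONNECTING
for message in (
    {"type": "websocket.accept"},
    {"type": "websocket.send", "text": "hi"},
    {"type": "websocket.close", "code": 1000},
):
    state = next_application_state(state, message)
print(state)
```

Keeping the transitions in one table-like function makes it easy to see that `DISCONNECTED` is terminal and that `RESPONSE` is reachable only from `CONNECTING`.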
================================================
FILE: tests/__init__.py
================================================
================================================
FILE: tests/conftest.py
================================================
from __future__ import annotations
import functools
from typing import Any, Literal
import pytest
from starlette.testclient import TestClient
from tests.types import TestClientFactory
@pytest.fixture
def test_client_factory(
anyio_backend_name: Literal["asyncio", "trio"],
anyio_backend_options: dict[str, Any],
) -> TestClientFactory:
# anyio_backend_name defined by:
# https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on
return functools.partial(
TestClient,
backend=anyio_backend_name,
backend_options=anyio_backend_options,
)
================================================
FILE: tests/middleware/__init__.py
================================================
================================================
FILE: tests/middleware/test_base.py
================================================
from __future__ import annotations
import contextvars
from collections.abc import AsyncGenerator, AsyncIterator, Generator
from contextlib import AsyncExitStack
from pathlib import Path
from typing import Any
import anyio
import pytest
from anyio.abc import TaskStatus
from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import ClientDisconnect, Request
from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket
from tests.types import TestClientFactory
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
response = await call_next(request)
response.headers["Custom-Header"] = "Example"
return response
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage")
def exc(request: Request) -> None:
raise Exception("Exc")
def exc_stream(request: Request) -> StreamingResponse:
return StreamingResponse(_generate_faulty_stream())
def _generate_faulty_stream() -> Generator[bytes, None, None]:
yield b"Ok"
raise Exception("Faulty Stream")
class NoResponse:
def __init__(
self,
scope: Scope,
receive: Receive,
send: Send,
):
pass
def __await__(self) -> Generator[Any, None, None]:
return self.dispatch().__await__()
async def dispatch(self) -> None:
pass
async def websocket_endpoint(session: WebSocket) -> None:
await session.accept()
await session.send_text("Hello, world!")
await session.close()
app = Starlette(
routes=[
Route("/", endpoint=homepage),
Route("/exc", endpoint=exc),
Route("/exc-stream", endpoint=exc_stream),
Route("/no-response", endpoint=NoResponse),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
],
middleware=[Middleware(CustomMiddleware)],
)
def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"
with pytest.raises(Exception) as ctx:
response = client.get("/exc")
assert str(ctx.value) == "Exc"
with pytest.raises(Exception) as ctx:
response = client.get("/exc-stream")
assert str(ctx.value) == "Faulty Stream"
with pytest.raises(RuntimeError):
response = client.get("/no-response")
with client.websocket_connect("/ws") as session:
text = session.receive_text()
assert text == "Hello, world!"
def test_state_data_across_multiple_middlewares(
test_client_factory: TestClientFactory,
) -> None:
expected_value1 = "foo"
expected_value2 = "bar"
class aMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
request.state.foo = expected_value1
response = await call_next(request)
return response
class bMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
request.state.bar = expected_value2
response = await call_next(request)
response.headers["X-State-Foo"] = request.state.foo
return response
class cMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
response = await call_next(request)
response.headers["X-State-Bar"] = request.state.bar
return response
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("OK")
app = Starlette(
routes=[Route("/", homepage)],
middleware=[
Middleware(aMiddleware),
Middleware(bMiddleware),
Middleware(cMiddleware),
],
)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "OK"
assert response.headers["X-State-Foo"] == expected_value1
assert response.headers["X-State-Bar"] == expected_value2
def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage")
app = Starlette(routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)])
client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"
def test_fully_evaluated_response(test_client_factory: TestClientFactory) -> None:
# Test for https://github.com/Kludex/starlette/issues/1022
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> PlainTextResponse:
await call_next(request)
return PlainTextResponse("Custom")
app = Starlette(middleware=[Middleware(CustomMiddleware)])
client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"
ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
class CustomMiddlewareWithoutBaseHTTPMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
ctxvar.set("set by middleware")
await self.app(scope, receive, send)
assert ctxvar.get() == "set by endpoint"
class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
ctxvar.set("set by middleware")
resp = await call_next(request)
assert ctxvar.get() == "set by endpoint"
return resp # pragma: no cover
@pytest.mark.parametrize(
"middleware_cls",
[
CustomMiddlewareWithoutBaseHTTPMiddleware,
pytest.param(
CustomMiddlewareUsingBaseHTTPMiddleware,
marks=pytest.mark.xfail(
reason=(
"BaseHTTPMiddleware creates a TaskGroup which copies the context "
"and erases any changes to it made within the TaskGroup"
),
raises=AssertionError,
),
),
],
)
def test_contextvars(
test_client_factory: TestClientFactory,
middleware_cls: _MiddlewareFactory[Any],
) -> None:
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints, which has its own set of peculiarities w.r.t. propagating
# contextvars (it propagates them forwards but not backwards)
async def homepage(request: Request) -> PlainTextResponse:
assert ctxvar.get() == "set by middleware"
ctxvar.set("set by endpoint")
return PlainTextResponse("Homepage")
app = Starlette(middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)])
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content
@pytest.mark.anyio
async def test_run_background_tasks_even_if_client_disconnects() -> None:
# test for https://github.com/Kludex/starlette/issues/1438
response_complete = anyio.Event()
background_task_run = anyio.Event()
async def sleep_and_set() -> None:
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
background_task_run.set()
async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
return PlainTextResponse(background=BackgroundTask(sleep_and_set))
async def passthrough(
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
return await call_next(request)
app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
routes=[Route("/", endpoint_with_background_task)],
)
scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}
async def receive() -> Message:
raise NotImplementedError("Should not be called!")
async def send(message: Message) -> None:
if message["type"] == "http.response.body":
if not message.get("more_body", False): # pragma: no branch
response_complete.set()
await app(scope, receive, send)
assert background_task_run.is_set()
def test_run_background_tasks_raise_exceptions(test_client_factory: TestClientFactory) -> None:
# test for https://github.com/Kludex/starlette/issues/2625
async def sleep_and_set() -> None:
await anyio.sleep(0.1)
raise ValueError("TEST")
async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
return PlainTextResponse(background=BackgroundTask(sleep_and_set))
async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)
app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
routes=[Route("/", endpoint_with_background_task)],
)
client = test_client_factory(app)
with pytest.raises(ValueError, match="TEST"):
client.get("/")
def test_exception_can_be_caught(test_client_factory: TestClientFactory) -> None:
async def error_endpoint(_: Request) -> None:
raise ValueError("TEST")
async def catches_error(request: Request, call_next: RequestResponseEndpoint) -> Response:
try:
return await call_next(request)
except ValueError as exc:
return PlainTextResponse(content=str(exc), status_code=400)
app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=catches_error)],
routes=[Route("/", error_endpoint)],
)
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 400
assert response.text == "TEST"
@pytest.mark.anyio
async def test_do_not_block_on_background_tasks() -> None:
response_complete = anyio.Event()
events: list[str | Message] = []
async def sleep_and_set() -> None:
events.append("Background task started")
await anyio.sleep(0.1)
events.append("Background task finished")
async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
return PlainTextResponse(content="Hello", background=BackgroundTask(sleep_and_set))
async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)
app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
routes=[Route("/", endpoint_with_background_task)],
)
scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}
async def receive() -> Message:
raise NotImplementedError("Should not be called!")
async def send(message: Message) -> None:
if message["type"] == "http.response.body":
events.append(message)
if not message.get("more_body", False):
response_complete.set()
async with anyio.create_task_group() as tg:
tg.start_soon(app, scope, receive, send)
tg.start_soon(app, scope, receive, send)
# Without the fix, the background tasks would start and finish before the
# last http.response.body is sent.
assert events == [
{"body": b"Hello", "more_body": True, "type": "http.response.body"},
{"body": b"", "more_body": False, "type": "http.response.body"},
{"body": b"Hello", "more_body": True, "type": "http.response.body"},
{"body": b"", "more_body": False, "type": "http.response.body"},
"Background task started",
"Background task started",
"Background task finished",
"Background task finished",
]
@pytest.mark.anyio
async def test_run_context_manager_exit_even_if_client_disconnects() -> None:
# test for https://github.com/Kludex/starlette/issues/1678#issuecomment-1172916042
response_complete = anyio.Event()
context_manager_exited = anyio.Event()
async def sleep_and_set() -> None:
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
context_manager_exited.set()
class ContextManagerMiddleware:
def __init__(self, app: ASGIApp):
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with AsyncExitStack() as stack:
stack.push_async_callback(sleep_and_set)
await self.app(scope, receive, send)
async def simple_endpoint(_: Request) -> PlainTextResponse:
return PlainTextResponse(background=BackgroundTask(sleep_and_set))
async def passthrough(
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
return await call_next(request)
app = Starlette(
middleware=[
Middleware(BaseHTTPMiddleware, dispatch=passthrough),
Middleware(ContextManagerMiddleware),
],
routes=[Route("/", simple_endpoint)],
)
scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}
async def receive() -> Message:
raise NotImplementedError("Should not be called!")
async def send(message: Message) -> None:
if message["type"] == "http.response.body":
if not message.get("more_body", False): # pragma: no branch
response_complete.set()
await app(scope, receive, send)
assert context_manager_exited.is_set()
def test_app_receives_http_disconnect_while_sending_if_discarded(
test_client_factory: TestClientFactory,
) -> None:
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: Any,
) -> PlainTextResponse:
# As a matter of ordering, this test targets the case where the downstream
# app response is discarded while it is sending a response body.
# We need to wait for the downstream app to begin sending a response body
# before sending the middleware response that will overwrite the downstream
# response.
downstream_app_response = await call_next(request)
body_generator = downstream_app_response.body_iterator
try:
await body_generator.__anext__()
finally:
await body_generator.aclose()
return PlainTextResponse("Custom")
async def downstream_app(
scope: Scope,
receive: Receive,
send: Send,
) -> None:
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/plain"),
],
}
)
async with anyio.create_task_group() as task_group:
async def cancel_on_disconnect(
*,
task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
) -> None:
task_status.started()
while True:
message = await receive()
if message["type"] == "http.disconnect": # pragma: no branch
task_group.cancel_scope.cancel()
break
# Using start instead of start_soon to ensure that
# cancel_on_disconnect is scheduled by the event loop
# before we start returning the body
await task_group.start(cancel_on_disconnect)
# A timeout of 0.1 seconds is set to ensure that
# we never deadlock the test run in an infinite loop
with anyio.move_on_after(0.1):
while True:
await send(
{
"type": "http.response.body",
"body": b"chunk ",
"more_body": True,
}
)
pytest.fail("http.disconnect should have been received and canceled the scope") # pragma: no cover
app = DiscardingMiddleware(downstream_app)
client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"
def test_app_receives_http_disconnect_after_sending_if_discarded(
test_client_factory: TestClientFactory,
) -> None:
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> PlainTextResponse:
await call_next(request)
return PlainTextResponse("Custom")
async def downstream_app(
scope: Scope,
receive: Receive,
send: Send,
) -> None:
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/plain"),
],
}
)
await send(
{
"type": "http.response.body",
"body": b"first chunk, ",
"more_body": True,
}
)
await send(
{
"type": "http.response.body",
"body": b"second chunk",
"more_body": True,
}
)
message = await receive()
assert message["type"] == "http.disconnect"
app = DiscardingMiddleware(downstream_app)
client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"
def test_read_request_stream_in_app_after_middleware_calls_stream(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
assert expected == []
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
assert expected == []
return await call_next(request)
app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)
client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200
def test_read_request_stream_in_app_after_middleware_calls_body(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
assert expected == []
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
assert await request.body() == b"a"
return await call_next(request)
app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)
client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200
def test_read_request_body_in_app_after_middleware_calls_stream(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b""
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
assert expected == []
return await call_next(request)
app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)
client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200
def test_read_request_body_in_app_after_middleware_calls_body(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
assert await request.body() == b"a"
return await call_next(request)
app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)
client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200
def test_read_request_stream_in_dispatch_after_app_calls_stream(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b"a", b""]
async for chunk in request.stream():
assert chunk == expected.pop(0)
assert expected == []
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
resp = await call_next(request)
with pytest.raises(RuntimeError, match="Stream consumed"):
async for _ in request.stream():
raise AssertionError("should not be called") # pragma: no cover
return resp
app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)
client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200
def test_read_request_stream_in_dispatch_after_app_calls_body(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
resp = await call_next(request)
with pytest.raises(RuntimeError, match="Stream consumed"):
async for _ in request.stream():
raise AssertionError("should not be called") # pragma: no cover
return resp
app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)
client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200
@pytest.mark.anyio
async def test_read_request_stream_in_dispatch_wrapping_app_calls_body() -> None:
async def endpoint(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
async for chunk in request.stream(): # pragma: no branch
assert chunk == b"2"
break
await Response()(scope, receive, send)
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
expected = b"1"
response: Response | None = None
async for chunk in request.stream(): # pragma: no branch
assert chunk == expected
if expected == b"1":
response = await call_next(request)
expected = b"3"
else:
break
assert response is not None
return response
async def rcv() -> AsyncGenerator[Message, None]:
yield {"type": "http.request", "body": b"1", "more_body": True}
yield {"type": "http.request", "body": b"2", "more_body": True}
yield {"type": "http.request", "body": b"3"}
raise AssertionError( # pragma: no cover
"Should not be called, no need to poll for disconnect"
)
sent: list[Message] = []
async def send(msg: Message) -> None:
sent.append(msg)
app: ASGIApp = endpoint
app = ConsumingMiddleware(app)
rcv_stream = rcv()
await app({"type": "http"}, rcv_stream.__anext__, send)
assert sent == [
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-length", b"0")],
},
{"type": "http.response.body", "body": b"", "more_body": False},
]
await rcv_stream.aclose()
def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
assert await request.body() == b"a" # this buffers the request body in memory
resp = await call_next(request)
async for chunk in request.stream():
if chunk:
assert chunk == b"a"
return resp
app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)
client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200
def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
return PlainTextResponse("Homepage")
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
assert await request.body() == b"a" # this buffers the request body in memory
resp = await call_next(request)
assert await request.body() == b"a" # no problem here
return resp
app = Starlette(
routes=[Route("/", homepage, methods=["POST"])],
middleware=[Middleware(ConsumingMiddleware)],
)
client: TestClient = test_client_factory(app)
response = client.post("/", content=b"a")
assert response.status_code == 200
@pytest.mark.anyio
async def test_read_request_disconnected_client() -> None:
"""If we receive a disconnect message when the downstream ASGI
app calls receive() the Request instance passed into the dispatch function
should get marked as disconnected.
The downstream ASGI app should not get a ClientDisconnect raised,
instead it should just receive the disconnect message.
"""
async def endpoint(scope: Scope, receive: Receive, send: Send) -> None:
msg = await receive()
assert msg["type"] == "http.disconnect"
await Response()(scope, receive, send)
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
response = await call_next(request)
disconnected = await request.is_disconnected()
assert disconnected is True
return response
scope = {"type": "http", "method": "POST", "path": "/"}
async def receive() -> AsyncGenerator[Message, None]:
yield {"type": "http.disconnect"}
raise AssertionError("Should not be called, would hang") # pragma: no cover
async def send(msg: Message) -> None:
if msg["type"] == "http.response.start":
assert msg["status"] == 200
app: ASGIApp = ConsumingMiddleware(endpoint)
rcv = receive()
await app(scope, rcv.__anext__, send)
await rcv.aclose()
@pytest.mark.anyio
async def test_read_request_disconnected_after_consuming_stream() -> None:
async def endpoint(scope: Scope, receive: Receive, send: Send) -> None:
msg = await receive()
assert msg.pop("more_body", False) is False
assert msg == {"type": "http.request", "body": b"hi"}
msg = await receive()
assert msg == {"type": "http.disconnect"}
await Response()(scope, receive, send)
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
await request.body()
disconnected = await request.is_disconnected()
assert disconnected is True
response = await call_next(request)
return response
scope = {"type": "http", "method": "POST", "path": "/"}
async def receive() -> AsyncGenerator[Message, None]:
yield {"type": "http.request", "body": b"hi"}
yield {"type": "http.disconnect"}
raise AssertionError("Should not be called, would hang") # pragma: no cover
async def send(msg: Message) -> None:
if msg["type"] == "http.response.start":
assert msg["status"] == 200
app: ASGIApp = ConsumingMiddleware(endpoint)
rcv = receive()
await app(scope, rcv.__anext__, send)
await rcv.aclose()
def test_downstream_middleware_modifies_receive(
test_client_factory: TestClientFactory,
) -> None:
"""If a downstream middleware modifies receive() the final ASGI app
should see the modified version.
"""
async def endpoint(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = await request.body()
assert body == b"foo foo "
await Response()(scope, receive, send)
class ConsumingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
body = await request.body()
assert body == b"foo "
return await call_next(request)
def modifying_middleware(app: ASGIApp) -> ASGIApp:
async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
async def wrapped_receive() -> Message:
msg = await receive()
if msg["type"] == "http.request": # pragma: no branch
msg["body"] = msg["body"] * 2
return msg
await app(scope, wrapped_receive, send)
return wrapped_app
client = test_client_factory(ConsumingMiddleware(modifying_middleware(endpoint)))
resp = client.post("/", content=b"foo ")
assert resp.status_code == 200
def test_pr_1519_comment_1236166180_example() -> None:
"""
https://github.com/Kludex/starlette/pull/1519#issuecomment-1236166180
"""
bodies: list[bytes] = []
class LogRequestBodySize(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
print(len(await request.body()))
return await call_next(request)
def replace_body_middleware(app: ASGIApp) -> ASGIApp:
async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
async def wrapped_rcv() -> Message:
msg = await receive()
msg["body"] += b"-foo"
return msg
await app(scope, wrapped_rcv, send)
return wrapped_app
async def endpoint(request: Request) -> Response:
body = await request.body()
bodies.append(body)
return Response()
app: ASGIApp = Starlette(routes=[Route("/", endpoint, methods=["POST"])])
app = replace_body_middleware(app)
app = LogRequestBodySize(app)
client = TestClient(app)
resp = client.post("/", content=b"Hello, World!")
resp.raise_for_status()
assert bodies == [b"Hello, World!-foo"]
@pytest.mark.anyio
async def test_multiple_middlewares_stacked_client_disconnected() -> None:
"""
Tests for:
- https://github.com/Kludex/starlette/issues/2516
- https://github.com/Kludex/starlette/pull/2687
"""
ordered_events: list[str] = []
unordered_events: list[str] = []
class MyMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, version: int) -> None:
self.version = version
super().__init__(app)
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
ordered_events.append(f"{self.version}:STARTED")
res = await call_next(request)
ordered_events.append(f"{self.version}:COMPLETED")
def background() -> None:
unordered_events.append(f"{self.version}:BACKGROUND")
assert res.background is None
res.background = BackgroundTask(background)
return res
async def sleepy(request: Request) -> Response:
try:
await request.body()
except ClientDisconnect:
pass
else: # pragma: no cover
raise AssertionError("Should have raised ClientDisconnect")
return Response(b"")
app = Starlette(
routes=[Route("/", sleepy)],
middleware=[Middleware(MyMiddleware, version=_ + 1) for _ in range(10)],
)
scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}
async def receive() -> AsyncIterator[Message]:
yield {"type": "http.disconnect"}
sent: list[Message] = []
async def send(message: Message) -> None:
sent.append(message)
await app(scope, receive().__anext__, send)
assert ordered_events == [
"1:STARTED",
"2:STARTED",
"3:STARTED",
"4:STARTED",
"5:STARTED",
"6:STARTED",
"7:STARTED",
"8:STARTED",
"9:STARTED",
"10:STARTED",
"10:COMPLETED",
"9:COMPLETED",
"8:COMPLETED",
"7:COMPLETED",
"6:COMPLETED",
"5:COMPLETED",
"4:COMPLETED",
"3:COMPLETED",
"2:COMPLETED",
"1:COMPLETED",
]
assert sorted(unordered_events) == sorted(
[
"1:BACKGROUND",
"2:BACKGROUND",
"3:BACKGROUND",
"4:BACKGROUND",
"5:BACKGROUND",
"6:BACKGROUND",
"7:BACKGROUND",
"8:BACKGROUND",
"9:BACKGROUND",
"10:BACKGROUND",
]
)
assert sent == [
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-length", b"0")],
},
{"type": "http.response.body", "body": b"", "more_body": False},
]
@pytest.mark.anyio
@pytest.mark.parametrize("send_body", [True, False])
async def test_poll_for_disconnect_repeated(send_body: bool) -> None:
async def app_poll_disconnect(scope: Scope, receive: Receive, send: Send) -> None:
for _ in range(2):
msg = await receive()
while msg["type"] == "http.request":
msg = await receive()
assert msg["type"] == "http.disconnect"
await Response(b"good!")(scope, receive, send)
class MyMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)
app = MyMiddleware(app_poll_disconnect)
scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}
async def receive() -> AsyncIterator[Message]:
# the key here is that we only ever send one http.disconnect message
if send_body:
yield {"type": "http.request", "body": b"hello", "more_body": True}
yield {"type": "http.request", "body": b"", "more_body": False}
yield {"type": "http.disconnect"}
raise AssertionError("Should not be called, would hang") # pragma: no cover
sent: list[Message] = []
async def send(message: Message) -> None:
sent.append(message)
await app(scope, receive().__anext__, send)
assert sent == [
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-length", b"5")],
},
{"type": "http.response.body", "body": b"good!", "more_body": True},
{"type": "http.response.body", "body": b"", "more_body": False},
]
@pytest.mark.anyio
async def test_asgi_pathsend_events(tmpdir: Path) -> None:
path = tmpdir / "example.txt"
with path.open("w") as file:
file.write("")
response_complete = anyio.Event()
events: list[Message] = []
async def endpoint_with_pathsend(_: Request) -> FileResponse:
return FileResponse(path)
async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)
app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
routes=[Route("/", endpoint_with_pathsend)],
)
scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
"headers": [],
"extensions": {"http.response.pathsend": {}},
}
async def receive() -> Message:
raise NotImplementedError("Should not be called!") # pragma: no cover
async def send(message: Message) -> None:
events.append(message)
if message["type"] == "http.response.pathsend":
response_complete.set()
await app(scope, receive, send)
assert len(events) == 2
assert events[0]["type"] == "http.response.start"
assert events[1]["type"] == "http.response.pathsend"
def test_error_context_propagation(test_client_factory: TestClientFactory) -> None:
class PassthroughMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
return await call_next(request)
def exception_without_context(request: Request) -> None:
raise Exception("Exception")
def exception_with_context(request: Request) -> None:
try:
raise Exception("Inner exception")
except Exception:
raise Exception("Outer exception")
def exception_with_cause(request: Request) -> None:
try:
raise Exception("Inner exception")
except Exception as e:
raise Exception("Outer exception") from e
app = Starlette(
routes=[
Route("/exception-without-context", endpoint=exception_without_context),
Route("/exception-with-context", endpoint=exception_with_context),
Route("/exception-with-cause", endpoint=exception_with_cause),
],
middleware=[Middleware(PassthroughMiddleware)],
)
client = test_client_factory(app)
# For exceptions without context, `__context__` is filled with the `anyio.EndOfStream`,
# but it is suppressed and therefore not shown in the traceback.
with pytest.raises(Exception) as ctx:
client.get("/exception-without-context")
assert str(ctx.value) == "Exception"
assert ctx.value.__cause__ is None
assert ctx.value.__context__ is not None
assert ctx.value.__suppress_context__ is True
# For exceptions with context, the context is propagated as a cause to prevent
# the `anyio.EndOfStream` error from overwriting it.
with pytest.raises(Exception) as ctx:
client.get("/exception-with-context")
assert str(ctx.value) == "Outer exception"
assert ctx.value.__cause__ is not None
assert str(ctx.value.__cause__) == "Inner exception"
# For exceptions with cause check that it gets correctly propagated.
with pytest.raises(Exception) as ctx:
client.get("/exception-with-cause")
assert str(ctx.value) == "Outer exception"
assert ctx.value.__cause__ is not None
assert str(ctx.value.__cause__) == "Inner exception"
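
The assertions in `test_error_context_propagation` rely on Python's exception chaining attributes (`__cause__`, `__context__`, `__suppress_context__`). The block below is a minimal, standard-library-only sketch of how plain `raise`, `raise ... from exc`, and `raise ... from None` set those attributes. It does not reproduce Starlette's or anyio's internals; the helper `capture` and the three raising functions are hypothetical names introduced only for illustration.

```python
def capture(func):
    """Hypothetical helper: call func and return the exception it raises."""
    try:
        func()
    except Exception as exc:
        return exc
    raise AssertionError("func did not raise")


def implicit_context():
    # Raising inside an except block records the handled error as __context__.
    try:
        raise KeyError("inner")
    except KeyError:
        raise ValueError("outer")


def explicit_cause():
    # `from inner` sets __cause__ and marks the context as suppressed.
    try:
        raise KeyError("inner")
    except KeyError as inner:
        raise ValueError("outer") from inner


def suppressed_context():
    # `from None` keeps __context__ but hides it from the traceback.
    try:
        raise KeyError("inner")
    except KeyError:
        raise ValueError("outer") from None


implicit = capture(implicit_context)
explicit = capture(explicit_cause)
suppressed = capture(suppressed_context)
```

The `suppressed` case has the same shape as the first group of assertions in the test: `__cause__` is `None`, `__context__` is still set, and `__suppress_context__` is `True`.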
================================================
FILE: tests/middleware/test_cors.py
================================================
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from tests.types import TestClientFactory
def test_cors_allow_all(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_headers=["*"],
allow_methods=["*"],
expose_headers=["X-Status"],
allow_credentials=True,
)
],
)
client = test_client_factory(app)
# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Example",
}
response = client.options("/", headers=headers)
assert response.status_code == 200
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-headers"] == "X-Example"
assert response.headers["access-control-allow-credentials"] == "true"
assert response.headers["vary"] == "Origin"
# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert response.headers["access-control-allow-credentials"] == "true"
# Test standard credentialed response
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert response.headers["access-control-allow-credentials"] == "true"
# Test non-CORS response
response = client.get("/")
assert response.status_code == 200
assert response.text == "Homepage"
assert "access-control-allow-origin" not in response.headers
def test_cors_allow_all_except_credentials(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_headers=["*"],
allow_methods=["*"],
expose_headers=["X-Status"],
)
],
)
client = test_client_factory(app)
# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Example",
}
response = client.options("/", headers=headers)
assert response.status_code == 200
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-allow-headers"] == "X-Example"
assert "access-control-allow-credentials" not in response.headers
assert "vary" not in response.headers
# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert "access-control-allow-credentials" not in response.headers
# Test non-CORS response
response = client.get("/")
assert response.status_code == 200
assert response.text == "Homepage"
assert "access-control-allow-origin" not in response.headers
def test_cors_allow_specific_origin(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[
Middleware(
CORSMiddleware,
allow_origins=["https://example.org"],
allow_headers=["X-Example", "Content-Type"],
)
],
)
client = test_client_factory(app)
# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Example, Content-Type",
}
response = client.options("/", headers=headers)
assert response.status_code == 200
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-headers"] == (
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
)
assert "access-control-allow-credentials" not in response.headers
# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert "access-control-allow-credentials" not in response.headers
# Test non-CORS response
response = client.get("/")
assert response.status_code == 200
assert response.text == "Homepage"
assert "access-control-allow-origin" not in response.headers
def test_cors_disallowed_preflight(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> None:
pass # pragma: no cover
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[
Middleware(
CORSMiddleware,
allow_origins=["https://example.org"],
allow_headers=["X-Example"],
)
],
)
client = test_client_factory(app)
# Test pre-flight response
headers = {
"Origin": "https://another.org",
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "X-Nope",
}
response = client.options("/", headers=headers)
assert response.status_code == 400
assert response.text == "Disallowed CORS origin, method, headers"
assert "access-control-allow-origin" not in response.headers
# Bug specific test, https://github.com/Kludex/starlette/pull/1199
# Test preflight response text with multiple disallowed headers
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Nope-1, X-Nope-2",
}
response = client.options("/", headers=headers)
assert response.text == "Disallowed CORS headers"
def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> None:
return # pragma: no cover
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["POST"],
allow_credentials=True,
)
],
)
client = test_client_factory(app)
# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "POST",
}
response = client.options(
"/",
headers=headers,
)
assert response.status_code == 200
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"
assert response.headers["vary"] == "Origin"
def test_cors_preflight_allow_all_methods(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> None:
pass # pragma: no cover
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])],
)
client = test_client_factory(app)
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "POST",
}
for method in ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"):
response = client.options("/", headers=headers)
assert response.status_code == 200
assert method in response.headers["access-control-allow-methods"]
def test_cors_allow_all_methods(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
routes=[
Route(
"/",
endpoint=homepage,
methods=["delete", "get", "head", "options", "patch", "post", "put"],
)
],
middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])],
)
client = test_client_factory(app)
headers = {"Origin": "https://example.org"}
for method in ("patch", "post", "put"):
response = getattr(client, method)("/", headers=headers, json={})
assert response.status_code == 200
for method in ("delete", "get", "head", "options"):
response = getattr(client, method)("/", headers=headers)
assert response.status_code == 200
def test_cors_allow_origin_regex(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[
Middleware(
CORSMiddleware,
allow_headers=["X-Example", "Content-Type"],
allow_origin_regex="https://.*",
allow_credentials=True,
)
],
)
client = test_client_factory(app)
# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"
# Test standard credentialed response
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"
# Test disallowed standard response
# Note that enforcement is a browser concern. The disallowed-ness is reflected
# in the lack of an "access-control-allow-origin" header in the response.
headers = {"Origin": "http://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert "access-control-allow-origin" not in response.headers
# Test pre-flight response
headers = {
"Origin": "https://another.com",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Example, content-type",
}
response = client.options("/", headers=headers)
assert response.status_code == 200
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "https://another.com"
assert response.headers["access-control-allow-headers"] == (
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
)
assert response.headers["access-control-allow-credentials"] == "true"
# Test disallowed pre-flight response
headers = {
"Origin": "http://another.com",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Example",
}
response = client.options("/", headers=headers)
assert response.status_code == 400
assert response.text == "Disallowed CORS origin"
assert "access-control-allow-origin" not in response.headers
def test_cors_allow_origin_regex_fullmatch(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[
Middleware(
CORSMiddleware,
allow_headers=["X-Example", "Content-Type"],
allow_origin_regex=r"https://.*\.example\.org",
)
],
)
client = test_client_factory(app)
# Test standard response
headers = {"Origin": "https://subdomain.example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://subdomain.example.org"
assert "access-control-allow-credentials" not in response.headers
# Test disallowed standard response
headers = {"Origin": "https://subdomain.example.org.hacker.com"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert "access-control-allow-origin" not in response.headers
def test_cors_vary_header_defaults_to_origin(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])],
)
headers = {"Origin": "https://example.org"}
client = test_client_factory(app)
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.headers["vary"] == "Origin"
def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(CORSMiddleware, allow_origins=["*"])],
)
client = test_client_factory(app)
response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.status_code == 200
assert response.headers["vary"] == "Accept-Encoding"
def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True)],
)
client = test_client_factory(app)
response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.status_code == 200
assert response.headers["vary"] == "Accept-Encoding, Origin"
def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
app = Starlette(
routes=[
Route("/", endpoint=homepage),
],
middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])],
)
client = test_client_factory(app)
response = client.get("/", headers={"Origin": "https://example.org"})
assert response.status_code == 200
assert response.headers["vary"] == "Accept-Encoding, Origin"
def test_cors_allowed_origin_does_not_leak_between_requests(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])],
)
client = test_client_factory(app)
response = client.get("/", headers={"Origin": "https://example.org"})
assert response.headers["access-control-allow-origin"] == "https://example.org"
response = client.get("/", headers={"Origin": "https://other.org"})
assert "access-control-allow-origin" not in response.headers
response = client.get("/", headers={"Origin": "https://example.org"})
assert response.headers["access-control-allow-origin"] == "https://example.org"
def test_cors_private_network_access_allowed(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_private_network=True,
)
],
)
client = test_client_factory(app)
headers_without_pna = {"Origin": "https://example.org", "Access-Control-Request-Method": "GET"}
headers_with_pna = {**headers_without_pna, "Access-Control-Request-Private-Network": "true"}
# Test preflight with Private Network Access request
response = client.options("/", headers=headers_with_pna)
assert response.status_code == 200
assert response.text == "OK"
assert response.headers["access-control-allow-private-network"] == "true"
# Test preflight without Private Network Access request
response = client.options("/", headers=headers_without_pna)
assert response.status_code == 200
assert response.text == "OK"
assert "access-control-allow-private-network" not in response.headers
# The access-control-allow-private-network header is not set for non-preflight requests
response = client.get("/", headers=headers_with_pna)
assert response.status_code == 200
assert response.text == "Homepage"
assert "access-control-allow-private-network" not in response.headers
assert "access-control-allow-origin" in response.headers
def test_cors_private_network_access_disallowed(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> None: ... # pragma: no cover
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_private_network=False,
)
],
)
client = test_client_factory(app)
# A preflight without a Private Network Access request still succeeds when PNA is disallowed
headers_without_pna = {"Origin": "https://example.org", "Access-Control-Request-Method": "GET"}
headers_with_pna = {**headers_without_pna, "Access-Control-Request-Private-Network": "true"}
response = client.options("/", headers=headers_without_pna)
assert response.status_code == 200
assert response.text == "OK"
assert "access-control-allow-private-network" not in response.headers
# If the request includes a Private Network Access header, but the middleware is configured to disallow it, the
# request should be denied with a 400 response.
response = client.options("/", headers=headers_with_pna)
assert response.status_code == 400
assert response.text == "Disallowed CORS private-network"
assert "access-control-allow-private-network" not in response.headers
================================================
FILE: tests/middleware/test_errors.py
================================================
from typing import Any
import pytest
from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route
from starlette.types import Receive, Scope, Send
from tests.types import TestClientFactory
def test_handler(
test_client_factory: TestClientFactory,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")
def error_500(request: Request, exc: Exception) -> JSONResponse:
return JSONResponse({"detail": "Server Error"}, status_code=500)
app = ServerErrorMiddleware(app, handler=error_500)
client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 500
assert response.json() == {"detail": "Server Error"}
def test_debug_text(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")
app = ServerErrorMiddleware(app, debug=True)
client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 500
assert response.headers["content-type"].startswith("text/plain")
assert "RuntimeError: Something went wrong" in response.text
def test_debug_html(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")
app = ServerErrorMiddleware(app, debug=True)
client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/", headers={"Accept": "text/html, */*"})
assert response.status_code == 500
assert response.headers["content-type"].startswith("text/html")
assert "RuntimeError" in response.text
def test_debug_after_response_sent(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response(b"", status_code=204)
await response(scope, receive, send)
raise RuntimeError("Something went wrong")
app = ServerErrorMiddleware(app, debug=True)
client = test_client_factory(app)
with pytest.raises(RuntimeError):
client.get("/")
def test_debug_not_http(test_client_factory: TestClientFactory) -> None:
"""
ServerErrorMiddleware should just pass through any non-http messages as-is.
"""
async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")
app = ServerErrorMiddleware(app)
with pytest.raises(RuntimeError):
client = test_client_factory(app)
with client.websocket_connect("/"):
pass # pragma: no cover
def test_background_task(test_client_factory: TestClientFactory) -> None:
accessed_error_handler = False
def error_handler(request: Request, exc: Exception) -> Any:
nonlocal accessed_error_handler
accessed_error_handler = True
def raise_exception() -> None:
raise Exception("Something went wrong")
async def endpoint(request: Request) -> Response:
task = BackgroundTask(raise_exception)
return Response(status_code=204, background=task)
app = Starlette(
routes=[Route("/", endpoint=endpoint)],
exception_handlers={Exception: error_handler},
)
client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 204
assert accessed_error_handler
================================================
FILE: tests/middleware/test_gzip.py
================================================
from __future__ import annotations
from pathlib import Path
import pytest
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import Request
from starlette.responses import ContentStream, FileResponse, PlainTextResponse, StreamingResponse
from starlette.routing import Route
from starlette.types import Message
from tests.types import TestClientFactory
def test_gzip_responses(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("x" * 4000, status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(GZipMiddleware)],
)
client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "gzip"
assert response.headers["Vary"] == "Accept-Encoding"
assert int(response.headers["Content-Length"]) < 4000
def test_gzip_not_in_accept_encoding(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("x" * 4000, status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(GZipMiddleware)],
)
client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "identity"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert "Content-Encoding" not in response.headers
assert response.headers["Vary"] == "Accept-Encoding"
assert int(response.headers["Content-Length"]) == 4000
def test_gzip_ignored_for_small_responses(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("OK", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(GZipMiddleware)],
)
client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "OK"
assert "Content-Encoding" not in response.headers
assert "Vary" not in response.headers
assert int(response.headers["Content-Length"]) == 2
def test_gzip_streaming_response(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> StreamingResponse:
async def generator(bytes: bytes, count: int) -> ContentStream:
for index in range(count):
yield bytes
streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(GZipMiddleware)],
)
client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "gzip"
assert response.headers["Vary"] == "Accept-Encoding"
assert "Content-Length" not in response.headers
def test_gzip_streaming_response_identity(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> StreamingResponse:
async def generator(bytes: bytes, count: int) -> ContentStream:
for index in range(count):
yield bytes
streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(GZipMiddleware)],
)
client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "identity"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert "Content-Encoding" not in response.headers
assert response.headers["Vary"] == "Accept-Encoding"
assert "Content-Length" not in response.headers
def test_gzip_ignored_for_responses_with_encoding_set(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> StreamingResponse:
async def generator(bytes: bytes, count: int) -> ContentStream:
for index in range(count):
yield bytes
streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200, headers={"Content-Encoding": "text"})
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(GZipMiddleware)],
)
client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip, text"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "text"
assert "Vary" not in response.headers
assert "Content-Length" not in response.headers
def test_gzip_ignored_on_server_sent_events(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> StreamingResponse:
async def generator(bytes: bytes, count: int) -> ContentStream:
for _ in range(count):
yield bytes
streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200, media_type="text/event-stream")
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(GZipMiddleware)],
)
client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert "Content-Encoding" not in response.headers
assert "Content-Length" not in response.headers
@pytest.mark.anyio
async def test_gzip_ignored_for_pathsend_responses(tmpdir: Path) -> None:
path = tmpdir / "example.txt"
with path.open("w") as file:
file.write("")
events: list[Message] = []
async def endpoint_with_pathsend(request: Request) -> FileResponse:
_ = await request.body()
return FileResponse(path)
app = Starlette(
routes=[Route("/", endpoint=endpoint_with_pathsend)],
middleware=[Middleware(GZipMiddleware)],
)
scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
"headers": [(b"accept-encoding", b"gzip, text")],
"extensions": {"http.response.pathsend": {}},
}
async def receive() -> Message:
return {"type": "http.request", "body": b"", "more_body": False}
async def send(message: Message) -> None:
events.append(message)
await app(scope, receive, send)
assert len(events) == 2
assert events[0]["type"] == "http.response.start"
assert events[1]["type"] == "http.response.pathsend"
================================================
FILE: tests/middleware/test_https_redirect.py
================================================
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from tests.types import TestClientFactory
def test_https_redirect_middleware(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("OK", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(HTTPSRedirectMiddleware)],
)
client = test_client_factory(app, base_url="https://testserver")
response = client.get("/")
assert response.status_code == 200
client = test_client_factory(app)
response = client.get("/", follow_redirects=False)
assert response.status_code == 307
assert response.headers["location"] == "https://testserver/"
client = test_client_factory(app, base_url="http://testserver:80")
response = client.get("/", follow_redirects=False)
assert response.status_code == 307
assert response.headers["location"] == "https://testserver/"
client = test_client_factory(app, base_url="http://testserver:443")
response = client.get("/", follow_redirects=False)
assert response.status_code == 307
assert response.headers["location"] == "https://testserver/"
client = test_client_factory(app, base_url="http://testserver:123")
response = client.get("/", follow_redirects=False)
assert response.status_code == 307
assert response.headers["location"] == "https://testserver:123/"
================================================
FILE: tests/middleware/test_middleware.py
================================================
from starlette.middleware import Middleware
from starlette.types import ASGIApp, Receive, Scope, Send
class CustomMiddleware: # pragma: no cover
def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None:
self.app = app
self.foo = foo
self.bar = bar
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
def test_middleware_repr() -> None:
middleware = Middleware(CustomMiddleware, "foo", bar=123)
assert repr(middleware) == "Middleware(CustomMiddleware, 'foo', bar=123)"
def test_middleware_iter() -> None:
cls, args, kwargs = Middleware(CustomMiddleware, "foo", bar=123)
assert (cls, args, kwargs) == (CustomMiddleware, ("foo",), {"bar": 123})
================================================
FILE: tests/middleware/test_session.py
================================================
import re
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.sessions import Session, SessionMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route
from starlette.testclient import TestClient
from tests.types import TestClientFactory
def view_session(request: Request) -> JSONResponse:
return JSONResponse({"session": request.session})
async def update_session(request: Request) -> JSONResponse:
data = await request.json()
request.session.update(data)
return JSONResponse({"session": request.session})
async def clear_session(request: Request) -> JSONResponse:
request.session.clear()
return JSONResponse({"session": request.session})
def no_session_access(request: Request) -> JSONResponse:
return JSONResponse({"status": "ok"})
def test_session(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
Route("/clear_session", endpoint=clear_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example")],
)
client = test_client_factory(app)
response = client.get("/view_session")
assert response.json() == {"session": {}}
response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
# check cookie max-age
set_cookie = response.headers["set-cookie"]
max_age_matches = re.search(r"; Max-Age=([0-9]+);", set_cookie)
assert max_age_matches is not None
assert int(max_age_matches[1]) == 14 * 24 * 3600
response = client.get("/view_session")
assert response.json() == {"session": {"some": "data"}}
response = client.post("/clear_session")
assert response.json() == {"session": {}}
response = client.get("/view_session")
assert response.json() == {"session": {}}
def test_session_expires(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=-1)],
)
client = test_client_factory(app)
response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
# The HTTP client removes expired cookies from response.cookies, so we need to
# fetch the session value from the Set-Cookie header and pass it explicitly
expired_cookie_header = response.headers["set-cookie"]
expired_session_match = re.search(r"session=([^;]*);", expired_cookie_header)
assert expired_session_match is not None
expired_session_value = expired_session_match[1]
client = test_client_factory(app, cookies={"session": expired_session_value})
response = client.get("/view_session")
assert response.json() == {"session": {}}
def test_secure_session(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
Route("/clear_session", endpoint=clear_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example", https_only=True)],
)
secure_client = test_client_factory(app, base_url="https://testserver")
unsecure_client = test_client_factory(app, base_url="http://testserver")
response = unsecure_client.get("/view_session")
assert response.json() == {"session": {}}
response = unsecure_client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
response = unsecure_client.get("/view_session")
assert response.json() == {"session": {}}
response = secure_client.get("/view_session")
assert response.json() == {"session": {}}
response = secure_client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
response = secure_client.get("/view_session")
assert response.json() == {"session": {"some": "data"}}
response = secure_client.post("/clear_session")
assert response.json() == {"session": {}}
response = secure_client.get("/view_session")
assert response.json() == {"session": {}}
def test_session_cookie_subpath(test_client_factory: TestClientFactory) -> None:
second_app = Starlette(
routes=[
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example", path="/second_app")],
)
app = Starlette(routes=[Mount("/second_app", app=second_app)])
client = test_client_factory(app, base_url="http://testserver/second_app")
response = client.post("/update_session", json={"some": "data"})
assert response.status_code == 200
cookie = response.headers["set-cookie"]
cookie_path_match = re.search(r"; path=(\S+);", cookie)
assert cookie_path_match is not None
cookie_path = cookie_path_match.groups()[0]
assert cookie_path == "/second_app"
def test_invalid_session_cookie(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example")],
)
client = test_client_factory(app)
response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
# A bogus session cookie should not raise an exception; the session is simply empty
client = test_client_factory(app, cookies={"session": "invalid"})
response = client.get("/view_session")
assert response.json() == {"session": {}}
def test_session_cookie(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=None)],
)
client: TestClient = test_client_factory(app)
response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
# with max_age=None, the cookie should carry no Max-Age attribute
set_cookie = response.headers["set-cookie"]
assert "Max-Age" not in set_cookie
client.cookies.delete("session")
response = client.get("/view_session")
assert response.json() == {"session": {}}
def test_domain_cookie(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example", domain=".example.com")],
)
client: TestClient = test_client_factory(app)
response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
# check cookie domain
set_cookie = response.headers["set-cookie"]
assert "domain=.example.com" in set_cookie
client.cookies.delete("session")
response = client.get("/view_session")
assert response.json() == {"session": {}}
def test_set_cookie_only_on_modification(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example")],
)
client = test_client_factory(app)
# Write to session - should send Set-Cookie
response = client.post("/update_session", json={"some": "data"})
assert "set-cookie" in response.headers
# Read-only access - should NOT send Set-Cookie
response = client.get("/view_session")
assert response.json() == {"session": {"some": "data"}}
assert "set-cookie" not in response.headers
def test_vary_cookie_on_access(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
Route("/no_session", endpoint=no_session_access),
],
middleware=[Middleware(SessionMiddleware, secret_key="example")],
)
client = test_client_factory(app)
# Modifying session should add Vary: Cookie
response = client.post("/update_session", json={"some": "data"})
assert "cookie" in response.headers.get("vary", "").lower()
# Reading a non-empty session should add Vary: Cookie
response = client.get("/view_session")
assert "cookie" in response.headers.get("vary", "").lower()
# Not accessing session at all should NOT add Vary: Cookie
response = client.get("/no_session")
assert "cookie" not in response.headers.get("vary", "").lower()
def test_session_tracks_modification() -> None:
session = Session({"a": "1", "b": "2"})
assert not session.modified
# __setitem__
session["c"] = "3"
assert session.modified
# __delitem__
session = Session({"a": "1"})
del session["a"]
assert session.modified
# clear
session = Session({"a": "1"})
session.clear()
assert session.modified
# pop with existing key
session = Session({"a": "1"})
session.pop("a")
assert session.modified
# pop with missing key
session = Session({"a": "1"})
session.pop("missing", None)
assert not session.modified
# setdefault with missing key
session = Session({"a": "1"})
session.setdefault("b", "2")
assert session.modified
# setdefault with existing key
session = Session({"a": "1"})
session.setdefault("a", "2")
assert not session.modified
================================================
FILE: tests/middleware/test_trusted_host.py
================================================
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from tests.types import TestClientFactory
def test_trusted_host_middleware(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("OK", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"])],
)
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200
client = test_client_factory(app, base_url="http://subdomain.testserver")
response = client.get("/")
assert response.status_code == 200
client = test_client_factory(app, base_url="http://invalidhost")
response = client.get("/")
assert response.status_code == 400
def test_default_allowed_hosts() -> None:
app = Starlette()
middleware = TrustedHostMiddleware(app)
assert middleware.allowed_hosts == ["*"]
def test_www_redirect(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("OK", status_code=200)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])],
)
client = test_client_factory(app, base_url="https://example.com")
response = client.get("/")
assert response.status_code == 200
assert response.url == "https://www.example.com/"
================================================
FILE: tests/middleware/test_wsgi.py
================================================
import sys
from collections.abc import Callable, Iterable
from typing import Any
import pytest
from starlette._utils import collapse_excgroups
from starlette.middleware.wsgi import WSGIMiddleware, build_environ
from tests.types import TestClientFactory
WSGIResponse = Iterable[bytes]
StartResponse = Callable[..., Any]
Environment = dict[str, Any]
def hello_world(
environ: Environment,
start_response: StartResponse,
) -> WSGIResponse:
status = "200 OK"
output = b"Hello World!\n"
headers = [
("Content-Type", "text/plain; charset=utf-8"),
("Content-Length", str(len(output))),
]
start_response(status, headers)
return [output]
def echo_body(
environ: Environment,
start_response: StartResponse,
) -> WSGIResponse:
status = "200 OK"
output = environ["wsgi.input"].read()
headers = [
("Content-Type", "text/plain; charset=utf-8"),
("Content-Length", str(len(output))),
]
start_response(status, headers)
return [output]
def raise_exception(
environ: Environment,
start_response: StartResponse,
) -> WSGIResponse:
raise RuntimeError("Something went wrong")
def return_exc_info(
environ: Environment,
start_response: StartResponse,
) -> WSGIResponse:
try:
raise RuntimeError("Something went wrong")
except RuntimeError:
status = "500 Internal Server Error"
output = b"Internal Server Error"
headers = [
("Content-Type", "text/plain; charset=utf-8"),
("Content-Length", str(len(output))),
]
start_response(status, headers, exc_info=sys.exc_info())
return [output]
def test_wsgi_get(test_client_factory: TestClientFactory) -> None:
app = WSGIMiddleware(hello_world)
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello World!\n"
def test_wsgi_post(test_client_factory: TestClientFactory) -> None:
app = WSGIMiddleware(echo_body)
client = test_client_factory(app)
response = client.post("/", json={"example": 123})
assert response.status_code == 200
assert response.text == '{"example":123}'
def test_wsgi_exception(test_client_factory: TestClientFactory) -> None:
# Note that we're testing the WSGI app directly here.
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(raise_exception)
client = test_client_factory(app)
with pytest.raises(RuntimeError), collapse_excgroups():
client.get("/")
def test_wsgi_exc_info(test_client_factory: TestClientFactory) -> None:
# Note that we're testing the WSGI app directly here.
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(return_exc_info)
client = test_client_factory(app)
with pytest.raises(RuntimeError):
response = client.get("/")
app = WSGIMiddleware(return_exc_info)
client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 500
assert response.text == "Internal Server Error"
def test_build_environ() -> None:
scope = {
"type": "http",
"http_version": "1.1",
"method": "GET",
"scheme": "https",
"path": "/sub/",
"root_path": "/sub",
"query_string": b"a=123&b=456",
"headers": [
(b"host", b"www.example.org"),
(b"content-type", b"application/json"),
(b"content-length", b"18"),
(b"accept", b"application/json"),
(b"accept", b"text/plain"),
],
"client": ("134.56.78.4", 1453),
"server": ("www.example.org", 443),
}
body = b'{"example":"body"}'
environ = build_environ(scope, body)
stream = environ.pop("wsgi.input")
assert stream.read() == b'{"example":"body"}'
assert environ == {
"CONTENT_LENGTH": "18",
"CONTENT_TYPE": "application/json",
"HTTP_ACCEPT": "application/json,text/plain",
"HTTP_HOST": "www.example.org",
"PATH_INFO": "/",
"QUERY_STRING": "a=123&b=456",
"REMOTE_ADDR": "134.56.78.4",
"REQUEST_METHOD": "GET",
"SCRIPT_NAME": "/sub",
"SERVER_NAME": "www.example.org",
"SERVER_PORT": 443,
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.errors": sys.stdout,
"wsgi.multiprocess": True,
"wsgi.multithread": True,
"wsgi.run_once": False,
"wsgi.url_scheme": "https",
"wsgi.version": (1, 0),
}
def test_build_environ_encoding() -> None:
scope = {
"type": "http",
"http_version": "1.1",
"method": "GET",
"path": "/小星",
"root_path": "/中国",
"query_string": b"a=123&b=456",
"headers": [],
}
environ = build_environ(scope, b"")
assert environ["SCRIPT_NAME"] == "/中国".encode().decode("latin-1")
assert environ["PATH_INFO"] == "/小星".encode().decode("latin-1")
================================================
FILE: tests/statics/example.txt
================================================
123
================================================
FILE: tests/test__utils.py
================================================
import functools
from typing import Any
from unittest.mock import create_autospec
import pytest
from starlette._utils import get_route_path, is_async_callable
from starlette.types import Scope
def test_async_func() -> None:
async def async_func() -> None: ... # pragma: no cover
def func() -> None: ... # pragma: no cover
assert is_async_callable(async_func)
assert not is_async_callable(func)
def test_async_partial() -> None:
async def async_func(a: Any, b: Any) -> None: ... # pragma: no cover
def func(a: Any, b: Any) -> None: ... # pragma: no cover
partial = functools.partial(async_func, 1)
assert is_async_callable(partial)
partial = functools.partial(func, 1) # type: ignore
assert not is_async_callable(partial)
def test_async_method() -> None:
class Async:
async def method(self) -> None: ... # pragma: no cover
class Sync:
def method(self) -> None: ... # pragma: no cover
assert is_async_callable(Async().method)
assert not is_async_callable(Sync().method)
def test_async_object_call() -> None:
class Async:
async def __call__(self) -> None: ... # pragma: no cover
class Sync:
def __call__(self) -> None: ... # pragma: no cover
assert is_async_callable(Async())
assert not is_async_callable(Sync())
def test_async_partial_object_call() -> None:
class Async:
async def __call__(
self,
a: Any,
b: Any,
) -> None: ... # pragma: no cover
class Sync:
def __call__(
self,
a: Any,
b: Any,
) -> None: ... # pragma: no cover
partial = functools.partial(Async(), 1)
assert is_async_callable(partial)
partial = functools.partial(Sync(), 1) # type: ignore
assert not is_async_callable(partial)
def test_async_nested_partial() -> None:
async def async_func(
a: Any,
b: Any,
) -> None: ... # pragma: no cover
partial = functools.partial(async_func, b=2)
nested_partial = functools.partial(partial, a=1)
assert is_async_callable(nested_partial)
def test_async_mocked_async_function() -> None:
async def async_func() -> None: ... # pragma: no cover
mock = create_autospec(async_func)
assert is_async_callable(mock)
@pytest.mark.parametrize(
"scope, expected_result",
[
({"path": "/foo-123/bar", "root_path": "/foo"}, "/foo-123/bar"),
({"path": "/foo/bar", "root_path": "/foo"}, "/bar"),
({"path": "/foo", "root_path": "/foo"}, ""),
({"path": "/foo/bar", "root_path": "/bar"}, "/foo/bar"),
],
)
def test_get_route_path(scope: Scope, expected_result: str) -> None:
assert get_route_path(scope) == expected_result
================================================
FILE: tests/test_applications.py
================================================
from __future__ import annotations
import os
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TypedDict
import anyio.from_thread
import pytest
from starlette import status
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware import Middleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.testclient import TestClient, WebSocketDenialResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket
from tests.types import TestClientFactory
async def error_500(request: Request, exc: HTTPException) -> JSONResponse:
    return JSONResponse({"detail": "Server Error"}, status_code=500)


async def method_not_allowed(request: Request, exc: HTTPException) -> JSONResponse:
    return JSONResponse({"detail": "Custom message"}, status_code=405)


async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
    return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)


def func_homepage(request: Request) -> PlainTextResponse:
    return PlainTextResponse("Hello, world!")


async def async_homepage(request: Request) -> PlainTextResponse:
    return PlainTextResponse("Hello, world!")


class Homepage(HTTPEndpoint):
    def get(self, request: Request) -> PlainTextResponse:
        return PlainTextResponse("Hello, world!")


def all_users_page(request: Request) -> PlainTextResponse:
    return PlainTextResponse("Hello, everyone!")


def user_page(request: Request) -> PlainTextResponse:
    username = request.path_params["username"]
    return PlainTextResponse(f"Hello, {username}!")


def custom_subdomain(request: Request) -> PlainTextResponse:
    return PlainTextResponse("Subdomain: " + request.path_params["subdomain"])


def runtime_error(request: Request) -> None:
    raise RuntimeError()
async def websocket_endpoint(session: WebSocket) -> None:
    await session.accept()
    await session.send_text("Hello, world!")
    await session.close()


async def websocket_raise_websocket_exception(websocket: WebSocket) -> None:
    await websocket.accept()
    raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)


async def websocket_raise_http_exception(websocket: WebSocket) -> None:
    raise HTTPException(status_code=401, detail="Unauthorized")


class CustomWSException(Exception):
    pass


async def websocket_raise_custom(websocket: WebSocket) -> None:
    await websocket.accept()
    raise CustomWSException()


async def websocket_state(websocket: WebSocket[CustomState]) -> None:
    await websocket.accept()
    await websocket.send_json({"count": websocket.state["count"]})
    await websocket.close()


def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> None:
    anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)


class CustomState(TypedDict):
    count: int


@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncGenerator[CustomState]:
    yield {"count": 1}


async def state_count(request: Request[CustomState]) -> JSONResponse:
    return JSONResponse({"count": request.state["count"]}, status_code=200)
users = Router(
    routes=[
        Route("/", endpoint=all_users_page),
        Route("/{username}", endpoint=user_page),
    ]
)

subdomain = Router(
    routes=[
        Route("/", custom_subdomain),
    ]
)

exception_handlers = {
    500: error_500,
    405: method_not_allowed,
    HTTPException: http_exception,
    CustomWSException: custom_ws_exception_handler,
}

middleware = [Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"])]

app = Starlette(
    routes=[
        Route("/func", endpoint=func_homepage),
        Route("/async", endpoint=async_homepage),
        Route("/class", endpoint=Homepage),
        Route("/state", endpoint=state_count),
        Route("/500", endpoint=runtime_error),
        WebSocketRoute("/ws", endpoint=websocket_endpoint),
        WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
        WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
        WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
        WebSocketRoute("/ws-state", endpoint=websocket_state),
        Mount("/users", app=users),
        Host("{subdomain}.example.org", app=subdomain),
    ],
    exception_handlers=exception_handlers,  # type: ignore
    middleware=middleware,
    lifespan=lifespan,
)
@pytest.fixture
def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
    with test_client_factory(app) as client:
        yield client


def test_url_path_for() -> None:
    assert app.url_path_for("func_homepage") == "/func"


def test_func_route(client: TestClient) -> None:
    response = client.get("/func")
    assert response.status_code == 200
    assert response.text == "Hello, world!"

    response = client.head("/func")
    assert response.status_code == 200
    assert response.text == ""


def test_async_route(client: TestClient) -> None:
    response = client.get("/async")
    assert response.status_code == 200
    assert response.text == "Hello, world!"


def test_class_route(client: TestClient) -> None:
    response = client.get("/class")
    assert response.status_code == 200
    assert response.text == "Hello, world!"


def test_mounted_route(client: TestClient) -> None:
    response = client.get("/users/")
    assert response.status_code == 200
    assert response.text == "Hello, everyone!"


def test_mounted_route_path_params(client: TestClient) -> None:
    response = client.get("/users/tomchristie")
    assert response.status_code == 200
    assert response.text == "Hello, tomchristie!"


def test_subdomain_route(test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app, base_url="https://foo.example.org/")

    response = client.get("/")
    assert response.status_code == 200
    assert response.text == "Subdomain: foo"


def test_websocket_route(client: TestClient) -> None:
    with client.websocket_connect("/ws") as session:
        text = session.receive_text()
        assert text == "Hello, world!"
def test_400(client: TestClient) -> None:
    response = client.get("/404")
    assert response.status_code == 404
    assert response.json() == {"detail": "Not Found"}


def test_405(client: TestClient) -> None:
    response = client.post("/func")
    assert response.status_code == 405
    assert response.json() == {"detail": "Custom message"}

    response = client.post("/class")
    assert response.status_code == 405
    assert response.json() == {"detail": "Custom message"}


def test_500(test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app, raise_server_exceptions=False)
    response = client.get("/500")
    assert response.status_code == 500
    assert response.json() == {"detail": "Server Error"}


def test_request_state(client: TestClient) -> None:
    response = client.get("/state")
    assert response.status_code == 200
    assert response.json() == {"count": 1}


def test_websocket_raise_websocket_exception(client: TestClient) -> None:
    with client.websocket_connect("/ws-raise-websocket") as session:
        response = session.receive()
        assert response == {
            "type": "websocket.close",
            "code": status.WS_1003_UNSUPPORTED_DATA,
            "reason": "",
        }


def test_websocket_state(client: TestClient) -> None:
    with client.websocket_connect("/ws-state") as session:
        response = session.receive_json()
        assert response == {"count": 1}


def test_websocket_raise_http_exception(client: TestClient) -> None:
    with pytest.raises(WebSocketDenialResponse) as exc:
        with client.websocket_connect("/ws-raise-http"):
            pass  # pragma: no cover
    assert exc.value.status_code == 401
    assert exc.value.content == b'{"detail":"Unauthorized"}'


def test_websocket_raise_custom_exception(client: TestClient) -> None:
    with client.websocket_connect("/ws-raise-custom") as session:
        response = session.receive()
        assert response == {
            "type": "websocket.close",
            "code": status.WS_1013_TRY_AGAIN_LATER,
            "reason": "",
        }


def test_middleware(test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app, base_url="http://incorrecthost")
    response = client.get("/func")
    assert response.status_code == 400
    assert response.text == "Invalid host header"
def test_routes() -> None:
    assert app.routes == [
        Route("/func", endpoint=func_homepage, methods=["GET"]),
        Route("/async", endpoint=async_homepage, methods=["GET"]),
        Route("/class", endpoint=Homepage),
        Route("/state", endpoint=state_count, methods=["GET"]),
        Route("/500", endpoint=runtime_error, methods=["GET"]),
        WebSocketRoute("/ws", endpoint=websocket_endpoint),
        WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
        WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
        WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
        WebSocketRoute("/ws-state", endpoint=websocket_state),
        Mount(
            "/users",
            app=Router(
                routes=[
                    Route("/", endpoint=all_users_page),
                    Route("/{username}", endpoint=user_page),
                ]
            ),
        ),
        Host(
            "{subdomain}.example.org",
            app=Router(routes=[Route("/", endpoint=custom_subdomain)]),
        ),
    ]
def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    path = os.path.join(tmpdir, "example.txt")
    with open(path, "w") as file:
        file.write("")

    app = Starlette(
        routes=[
            Mount("/static", StaticFiles(directory=tmpdir)),
        ]
    )

    client = test_client_factory(app)

    response = client.get("/static/example.txt")
    assert response.status_code == 200
    assert response.text == ""

    response = client.post("/static/example.txt")
    assert response.status_code == 405
    assert response.text == "Method Not Allowed"


def test_app_debug(test_client_factory: TestClientFactory) -> None:
    async def homepage(request: Request) -> None:
        raise RuntimeError()

    app = Starlette(
        routes=[
            Route("/", homepage),
        ],
    )
    app.debug = True

    client = test_client_factory(app, raise_server_exceptions=False)
    response = client.get("/")
    assert response.status_code == 500
    assert "RuntimeError" in response.text
    assert app.debug


def test_app_add_route(test_client_factory: TestClientFactory) -> None:
    async def homepage(request: Request) -> PlainTextResponse:
        return PlainTextResponse("Hello, World!")

    app = Starlette(
        routes=[
            Route("/", endpoint=homepage),
        ]
    )

    client = test_client_factory(app)
    response = client.get("/")
    assert response.status_code == 200
    assert response.text == "Hello, World!"


def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None:
    async def websocket_endpoint(session: WebSocket) -> None:
        await session.accept()
        await session.send_text("Hello, world!")
        await session.close()

    app = Starlette(
        routes=[
            WebSocketRoute("/ws", endpoint=websocket_endpoint),
        ]
    )
    client = test_client_factory(app)

    with client.websocket_connect("/ws") as session:
        text = session.receive_text()
        assert text == "Hello, world!"
def test_app_async_cm_lifespan(test_client_factory: TestClientFactory) -> None:
    startup_complete = False
    cleanup_complete = False

    @asynccontextmanager
    async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
        nonlocal startup_complete, cleanup_complete
        startup_complete = True
        yield
        cleanup_complete = True

    app = Starlette(lifespan=lifespan)

    assert not startup_complete
    assert not cleanup_complete
    with test_client_factory(app):
        assert startup_complete
        assert not cleanup_complete
    assert startup_complete
    assert cleanup_complete


deprecated_lifespan = pytest.mark.filterwarnings(
    r"ignore"
    r":(async )?generator function lifespans are deprecated, use an "
    r"@contextlib\.asynccontextmanager function instead"
    r":DeprecationWarning"
    r":starlette.routing"
)


@deprecated_lifespan
def test_app_async_gen_lifespan(test_client_factory: TestClientFactory) -> None:
    startup_complete = False
    cleanup_complete = False

    async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
        nonlocal startup_complete, cleanup_complete
        startup_complete = True
        yield
        cleanup_complete = True

    app = Starlette(lifespan=lifespan)  # type: ignore

    assert not startup_complete
    assert not cleanup_complete
    with test_client_factory(app):
        assert startup_complete
        assert not cleanup_complete
    assert startup_complete
    assert cleanup_complete


@deprecated_lifespan
def test_app_sync_gen_lifespan(test_client_factory: TestClientFactory) -> None:
    startup_complete = False
    cleanup_complete = False

    def lifespan(app: ASGIApp) -> Generator[None, None, None]:
        nonlocal startup_complete, cleanup_complete
        startup_complete = True
        yield
        cleanup_complete = True

    app = Starlette(lifespan=lifespan)  # type: ignore

    assert not startup_complete
    assert not cleanup_complete
    with test_client_factory(app):
        assert startup_complete
        assert not cleanup_complete
    assert startup_complete
    assert cleanup_complete
def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None:
    class NoOpMiddleware:
        def __init__(self, app: ASGIApp):
            self.app = app

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            await self.app(scope, receive, send)

    class SimpleInitializableMiddleware:
        counter = 0

        def __init__(self, app: ASGIApp):
            self.app = app
            SimpleInitializableMiddleware.counter += 1

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            await self.app(scope, receive, send)

    def get_app() -> ASGIApp:
        app = Starlette()
        app.add_middleware(SimpleInitializableMiddleware)
        app.add_middleware(NoOpMiddleware)
        return app

    app = get_app()

    with test_client_factory(app):
        pass

    assert SimpleInitializableMiddleware.counter == 1

    test_client_factory(app).get("/foo")

    assert SimpleInitializableMiddleware.counter == 1

    app = get_app()

    test_client_factory(app).get("/foo")

    assert SimpleInitializableMiddleware.counter == 2


def test_middleware_args(test_client_factory: TestClientFactory) -> None:
    calls: list[str] = []

    class MiddlewareWithArgs:
        def __init__(self, app: ASGIApp, arg: str) -> None:
            self.app = app
            self.arg = arg

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            calls.append(self.arg)
            await self.app(scope, receive, send)

    app = Starlette()
    app.add_middleware(MiddlewareWithArgs, "foo")
    app.add_middleware(MiddlewareWithArgs, "bar")

    with test_client_factory(app):
        pass

    assert calls == ["bar", "foo"]


def test_middleware_factory(test_client_factory: TestClientFactory) -> None:
    calls: list[str] = []

    def _middleware_factory(app: ASGIApp, arg: str) -> ASGIApp:
        async def _app(scope: Scope, receive: Receive, send: Send) -> None:
            calls.append(arg)
            await app(scope, receive, send)

        return _app

    def get_middleware_factory() -> Callable[[ASGIApp, str], ASGIApp]:
        return _middleware_factory

    app = Starlette()
    app.add_middleware(_middleware_factory, arg="foo")
    app.add_middleware(get_middleware_factory(), "bar")

    with test_client_factory(app):
        pass

    assert calls == ["bar", "foo"]
def test_lifespan_app_subclass() -> None:
    # This test exists to make sure that subclasses of Starlette
    # (like FastAPI) are compatible with the type hints for Lifespan

    class App(Starlette):
        pass

    @asynccontextmanager
    async def lifespan(app: App) -> AsyncIterator[None]:  # pragma: no cover
        yield

    App(lifespan=lifespan)
================================================
FILE: tests/test_authentication.py
================================================
from __future__ import annotations
import base64
import binascii
from collections.abc import Awaitable, Callable
from typing import Any
from urllib.parse import urlencode
import pytest
from starlette.applications import Starlette
from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError, SimpleUser, requires
from starlette.endpoints import HTTPEndpoint
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import HTTPConnection, Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route, WebSocketRoute
from starlette.websockets import WebSocket, WebSocketDisconnect
from tests.types import TestClientFactory
AsyncEndpoint = Callable[..., Awaitable[Response]]
SyncEndpoint = Callable[..., Response]
class BasicAuth(AuthenticationBackend):
    async def authenticate(
        self,
        request: HTTPConnection,
    ) -> tuple[AuthCredentials, SimpleUser] | None:
        if "Authorization" not in request.headers:
            return None

        auth = request.headers["Authorization"]
        try:
            scheme, credentials = auth.split()
            decoded = base64.b64decode(credentials).decode("ascii")
        except (ValueError, UnicodeDecodeError, binascii.Error):
            raise AuthenticationError("Invalid basic auth credentials")

        username, _, password = decoded.partition(":")
        return AuthCredentials(["authenticated"]), SimpleUser(username)
def homepage(request: Request) -> JSONResponse:
    return JSONResponse(
        {
            "authenticated": request.user.is_authenticated,
            "user": request.user.display_name,
        }
    )


@requires("authenticated")
async def dashboard(request: Request) -> JSONResponse:
    return JSONResponse(
        {
            "authenticated": request.user.is_authenticated,
            "user": request.user.display_name,
        }
    )


@requires("authenticated", redirect="homepage")
async def admin(request: Request) -> JSONResponse:
    return JSONResponse(
        {
            "authenticated": request.user.is_authenticated,
            "user": request.user.display_name,
        }
    )


@requires("authenticated")
def dashboard_sync(request: Request) -> JSONResponse:
    return JSONResponse(
        {
            "authenticated": request.user.is_authenticated,
            "user": request.user.display_name,
        }
    )


class Dashboard(HTTPEndpoint):
    @requires("authenticated")
    def get(self, request: Request) -> JSONResponse:
        return JSONResponse(
            {
                "authenticated": request.user.is_authenticated,
                "user": request.user.display_name,
            }
        )


@requires("authenticated", redirect="homepage")
def admin_sync(request: Request) -> JSONResponse:
    return JSONResponse(
        {
            "authenticated": request.user.is_authenticated,
            "user": request.user.display_name,
        }
    )


@requires("authenticated")
async def websocket_endpoint(websocket: WebSocket) -> None:
    await websocket.accept()
    await websocket.send_json(
        {
            "authenticated": websocket.user.is_authenticated,
            "user": websocket.user.display_name,
        }
    )
def async_inject_decorator(
    **kwargs: Any,
) -> Callable[[AsyncEndpoint], Callable[..., Awaitable[Response]]]:
    def wrapper(endpoint: AsyncEndpoint) -> Callable[..., Awaitable[Response]]:
        async def app(request: Request) -> Response:
            return await endpoint(request=request, **kwargs)

        return app

    return wrapper


@async_inject_decorator(additional="payload")
@requires("authenticated")
async def decorated_async(request: Request, additional: str) -> JSONResponse:
    return JSONResponse(
        {
            "authenticated": request.user.is_authenticated,
            "user": request.user.display_name,
            "additional": additional,
        }
    )


def sync_inject_decorator(
    **kwargs: Any,
) -> Callable[[SyncEndpoint], Callable[..., Response]]:
    def wrapper(endpoint: SyncEndpoint) -> Callable[..., Response]:
        def app(request: Request) -> Response:
            return endpoint(request=request, **kwargs)

        return app

    return wrapper


@sync_inject_decorator(additional="payload")
@requires("authenticated")
def decorated_sync(request: Request, additional: str) -> JSONResponse:
    return JSONResponse(
        {
            "authenticated": request.user.is_authenticated,
            "user": request.user.display_name,
            "additional": additional,
        }
    )


def ws_inject_decorator(**kwargs: Any) -> Callable[..., AsyncEndpoint]:
    def wrapper(endpoint: AsyncEndpoint) -> AsyncEndpoint:
        def app(websocket: WebSocket) -> Awaitable[Response]:
            return endpoint(websocket=websocket, **kwargs)

        return app

    return wrapper


@ws_inject_decorator(additional="payload")
@requires("authenticated")
async def websocket_endpoint_decorated(websocket: WebSocket, additional: str) -> None:
    await websocket.accept()
    await websocket.send_json(
        {
            "authenticated": websocket.user.is_authenticated,
            "user": websocket.user.display_name,
            "additional": additional,
        }
    )
app = Starlette(
    middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuth())],
    routes=[
        Route("/", endpoint=homepage),
        Route("/dashboard", endpoint=dashboard),
        Route("/admin", endpoint=admin),
        Route("/dashboard/sync", endpoint=dashboard_sync),
        Route("/dashboard/class", endpoint=Dashboard),
        Route("/admin/sync", endpoint=admin_sync),
        Route("/dashboard/decorated", endpoint=decorated_async),
        Route("/dashboard/decorated/sync", endpoint=decorated_sync),
        WebSocketRoute("/ws", endpoint=websocket_endpoint),
        WebSocketRoute("/ws/decorated", endpoint=websocket_endpoint_decorated),
    ],
)


def test_invalid_decorator_usage() -> None:
    with pytest.raises(Exception):

        @requires("authenticated")
        def foo() -> None:  # pragma: no cover
            pass
def test_user_interface(test_client_factory: TestClientFactory) -> None:
    with test_client_factory(app) as client:
        response = client.get("/")
        assert response.status_code == 200
        assert response.json() == {"authenticated": False, "user": ""}

        response = client.get("/", auth=("tomchristie", "example"))
        assert response.status_code == 200
        assert response.json() == {"authenticated": True, "user": "tomchristie"}


def test_authentication_required(test_client_factory: TestClientFactory) -> None:
    with test_client_factory(app) as client:
        response = client.get("/dashboard")
        assert response.status_code == 403

        response = client.get("/dashboard", auth=("tomchristie", "example"))
        assert response.status_code == 200
        assert response.json() == {"authenticated": True, "user": "tomchristie"}

        response = client.get("/dashboard/sync")
        assert response.status_code == 403

        response = client.get("/dashboard/sync", auth=("tomchristie", "example"))
        assert response.status_code == 200
        assert response.json() == {"authenticated": True, "user": "tomchristie"}

        response = client.get("/dashboard/class")
        assert response.status_code == 403

        response = client.get("/dashboard/class", auth=("tomchristie", "example"))
        assert response.status_code == 200
        assert response.json() == {"authenticated": True, "user": "tomchristie"}

        response = client.get("/dashboard/decorated", auth=("tomchristie", "example"))
        assert response.status_code == 200
        assert response.json() == {
            "authenticated": True,
            "user": "tomchristie",
            "additional": "payload",
        }

        response = client.get("/dashboard/decorated")
        assert response.status_code == 403

        response = client.get("/dashboard/decorated/sync", auth=("tomchristie", "example"))
        assert response.status_code == 200
        assert response.json() == {
            "authenticated": True,
            "user": "tomchristie",
            "additional": "payload",
        }

        response = client.get("/dashboard/decorated/sync")
        assert response.status_code == 403

        response = client.get("/dashboard", headers={"Authorization": "basic foobar"})
        assert response.status_code == 400
        assert response.text == "Invalid basic auth credentials"
def test_websocket_authentication_required(
    test_client_factory: TestClientFactory,
) -> None:
    with test_client_factory(app) as client:
        with pytest.raises(WebSocketDisconnect):
            with client.websocket_connect("/ws"):
                pass  # pragma: no cover

        with pytest.raises(WebSocketDisconnect):
            with client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}):
                pass  # pragma: no cover

        with client.websocket_connect("/ws", auth=("tomchristie", "example")) as websocket:
            data = websocket.receive_json()
            assert data == {"authenticated": True, "user": "tomchristie"}

        with pytest.raises(WebSocketDisconnect):
            with client.websocket_connect("/ws/decorated"):
                pass  # pragma: no cover

        with pytest.raises(WebSocketDisconnect):
            with client.websocket_connect("/ws/decorated", headers={"Authorization": "basic foobar"}):
                pass  # pragma: no cover

        with client.websocket_connect("/ws/decorated", auth=("tomchristie", "example")) as websocket:
            data = websocket.receive_json()
            assert data == {
                "authenticated": True,
                "user": "tomchristie",
                "additional": "payload",
            }
def test_authentication_redirect(test_client_factory: TestClientFactory) -> None:
    with test_client_factory(app) as client:
        response = client.get("/admin")
        assert response.status_code == 200
        url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin"}))
        assert response.url == url

        response = client.get("/admin", auth=("tomchristie", "example"))
        assert response.status_code == 200
        assert response.json() == {"authenticated": True, "user": "tomchristie"}

        response = client.get("/admin/sync")
        assert response.status_code == 200
        url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin/sync"}))
        assert response.url == url

        response = client.get("/admin/sync", auth=("tomchristie", "example"))
        assert response.status_code == 200
        assert response.json() == {"authenticated": True, "user": "tomchristie"}


def on_auth_error(request: HTTPConnection, exc: AuthenticationError) -> JSONResponse:
    return JSONResponse({"error": str(exc)}, status_code=401)


@requires("authenticated")
def control_panel(request: Request) -> JSONResponse:
    return JSONResponse(
        {
            "authenticated": request.user.is_authenticated,
            "user": request.user.display_name,
        }
    )


other_app = Starlette(
    routes=[Route("/control-panel", control_panel)],
    middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error)],
)


def test_custom_on_error(test_client_factory: TestClientFactory) -> None:
    with test_client_factory(other_app) as client:
        response = client.get("/control-panel", auth=("tomchristie", "example"))
        assert response.status_code == 200
        assert response.json() == {"authenticated": True, "user": "tomchristie"}

        response = client.get("/control-panel", headers={"Authorization": "basic foobar"})
        assert response.status_code == 401
        assert response.json() == {"error": "Invalid basic auth credentials"}
================================================
FILE: tests/test_background.py
================================================
import pytest
from starlette.background import BackgroundTask, BackgroundTasks
from starlette.responses import Response
from starlette.types import Receive, Scope, Send
from tests.types import TestClientFactory
def test_async_task(test_client_factory: TestClientFactory) -> None:
    TASK_COMPLETE = False

    async def async_task() -> None:
        nonlocal TASK_COMPLETE
        TASK_COMPLETE = True

    task = BackgroundTask(async_task)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        response = Response("task initiated", media_type="text/plain", background=task)
        await response(scope, receive, send)

    client = test_client_factory(app)
    response = client.get("/")
    assert response.text == "task initiated"
    assert TASK_COMPLETE


def test_sync_task(test_client_factory: TestClientFactory) -> None:
    TASK_COMPLETE = False

    def sync_task() -> None:
        nonlocal TASK_COMPLETE
        TASK_COMPLETE = True

    task = BackgroundTask(sync_task)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        response = Response("task initiated", media_type="text/plain", background=task)
        await response(scope, receive, send)

    client = test_client_factory(app)
    response = client.get("/")
    assert response.text == "task initiated"
    assert TASK_COMPLETE


def test_multiple_tasks(test_client_factory: TestClientFactory) -> None:
    TASK_COUNTER = 0

    def increment(amount: int) -> None:
        nonlocal TASK_COUNTER
        TASK_COUNTER += amount

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        tasks = BackgroundTasks()
        tasks.add_task(increment, amount=1)
        tasks.add_task(increment, amount=2)
        tasks.add_task(increment, amount=3)
        response = Response("tasks initiated", media_type="text/plain", background=tasks)
        await response(scope, receive, send)

    client = test_client_factory(app)
    response = client.get("/")
    assert response.text == "tasks initiated"
    assert TASK_COUNTER == 1 + 2 + 3


def test_multi_tasks_failure_avoids_next_execution(
    test_client_factory: TestClientFactory,
) -> None:
    TASK_COUNTER = 0

    def increment() -> None:
        nonlocal TASK_COUNTER
        TASK_COUNTER += 1
        if TASK_COUNTER == 1:  # pragma: no branch
            raise Exception("task failed")

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        tasks = BackgroundTasks()
        tasks.add_task(increment)
        tasks.add_task(increment)
        response = Response("tasks initiated", media_type="text/plain", background=tasks)
        await response(scope, receive, send)

    client = test_client_factory(app)
    with pytest.raises(Exception):
        client.get("/")
    assert TASK_COUNTER == 1
================================================
FILE: tests/test_concurrency.py
================================================
from collections.abc import Iterator
from contextvars import ContextVar
import anyio
import pytest
from starlette.applications import Starlette
from starlette.concurrency import iterate_in_threadpool, run_until_first_complete
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
from tests.types import TestClientFactory
@pytest.mark.anyio
async def test_run_until_first_complete() -> None:
    task1_finished = anyio.Event()
    task2_finished = anyio.Event()

    async def task1() -> None:
        task1_finished.set()

    async def task2() -> None:
        await task1_finished.wait()
        await anyio.sleep(0)  # pragma: no cover
        task2_finished.set()  # pragma: no cover

    await run_until_first_complete((task1, {}), (task2, {}))
    assert task1_finished.is_set()
    assert not task2_finished.is_set()


def test_accessing_context_from_threaded_sync_endpoint(
    test_client_factory: TestClientFactory,
) -> None:
    ctxvar: ContextVar[bytes] = ContextVar("ctxvar")
    ctxvar.set(b"data")

    def endpoint(request: Request) -> Response:
        return Response(ctxvar.get())

    app = Starlette(routes=[Route("/", endpoint)])
    client = test_client_factory(app)

    resp = client.get("/")
    assert resp.content == b"data"


@pytest.mark.anyio
async def test_iterate_in_threadpool() -> None:
    class CustomIterable:
        def __iter__(self) -> Iterator[int]:
            yield from range(3)

    assert [v async for v in iterate_in_threadpool(CustomIterable())] == [0, 1, 2]
================================================
FILE: tests/test_config.py
================================================
import os
from pathlib import Path
from typing import Any
import pytest
from typing_extensions import assert_type
from starlette.config import Config, Environ, EnvironError
from starlette.datastructures import URL, Secret
def test_config_types() -> None:
    """
    We use `assert_type` to test the types returned by Config via mypy.
    """
    config = Config(environ={"STR": "some_str_value", "STR_CAST": "some_str_value", "BOOL": "true"})

    assert_type(config("STR"), str)
    assert_type(config("STR_DEFAULT", default=""), str)
    assert_type(config("STR_CAST", cast=str), str)
    assert_type(config("STR_NONE", default=None), str | None)
    assert_type(config("STR_CAST_NONE", cast=str, default=None), str | None)
    assert_type(config("STR_CAST_STR", cast=str, default=""), str)

    assert_type(config("BOOL", cast=bool), bool)
    assert_type(config("BOOL_DEFAULT", cast=bool, default=False), bool)
    assert_type(config("BOOL_NONE", cast=bool, default=None), bool | None)

    def cast_to_int(v: Any) -> int:
        return int(v)

    # our type annotations allow these `cast` and `default` configurations, but
    # the code will error at runtime.
    with pytest.raises(ValueError):
        config("INT_CAST_DEFAULT_STR", cast=cast_to_int, default="true")
    with pytest.raises(ValueError):
        config("INT_DEFAULT_STR", cast=int, default="true")
def test_config(tmpdir: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    path = os.path.join(tmpdir, ".env")
    with open(path, "w") as file:
        file.write("# Do not commit to source control\n")
        file.write("DATABASE_URL=postgres://user:pass@localhost/dbname\n")
        file.write("REQUEST_HOSTNAME=example.com\n")
        file.write("SECRET_KEY=12345\n")
        file.write("BOOL_AS_INT=0\n")
        file.write("\n")
        file.write("\n")

    config = Config(path, environ={"DEBUG": "true"})

    def cast_to_int(v: Any) -> int:
        return int(v)

    DEBUG = config("DEBUG", cast=bool)
    DATABASE_URL = config("DATABASE_URL", cast=URL)
    REQUEST_TIMEOUT = config("REQUEST_TIMEOUT", cast=int, default=10)
    REQUEST_HOSTNAME = config("REQUEST_HOSTNAME")
    MAIL_HOSTNAME = config("MAIL_HOSTNAME", default=None)
    SECRET_KEY = config("SECRET_KEY", cast=Secret)
    UNSET_SECRET = config("UNSET_SECRET", cast=Secret, default=None)
    EMPTY_SECRET = config("EMPTY_SECRET", cast=Secret, default="")
    assert config("BOOL_AS_INT", cast=bool) is False
    assert config("BOOL_AS_INT", cast=cast_to_int) == 0
    assert config("DEFAULTED_BOOL", cast=cast_to_int, default=True) == 1

    assert DEBUG is True
    assert DATABASE_URL.path == "/dbname"
    assert DATABASE_URL.password == "pass"
    assert DATABASE_URL.username == "user"
    assert REQUEST_TIMEOUT == 10
    assert REQUEST_HOSTNAME == "example.com"
    assert MAIL_HOSTNAME is None
    assert repr(SECRET_KEY) == "Secret('**********')"
    assert str(SECRET_KEY) == "12345"
    assert bool(SECRET_KEY)
    assert not bool(EMPTY_SECRET)
    assert not bool(UNSET_SECRET)

    with pytest.raises(KeyError):
        config.get("MISSING")

    with pytest.raises(ValueError):
        config.get("DEBUG", cast=int)

    with pytest.raises(ValueError):
        config.get("REQUEST_HOSTNAME", cast=bool)

    config = Config(Path(path))
    REQUEST_HOSTNAME = config("REQUEST_HOSTNAME")
    assert REQUEST_HOSTNAME == "example.com"

    config = Config()
    monkeypatch.setenv("STARLETTE_EXAMPLE_TEST", "123")
    monkeypatch.setenv("BOOL_AS_INT", "1")
    assert config.get("STARLETTE_EXAMPLE_TEST", cast=int) == 123
    assert config.get("BOOL_AS_INT", cast=bool) is True

    monkeypatch.setenv("BOOL_AS_INT", "2")
    with pytest.raises(ValueError):
        config.get("BOOL_AS_INT", cast=bool)
def test_missing_env_file_raises(tmpdir: Path) -> None:
path = os.path.join(tmpdir, ".env")
with pytest.warns(UserWarning, match=f"Config file '{path}' not found."):
Config(path)
def test_environ() -> None:
environ = Environ()
# We can mutate the environ at this point.
environ["TESTING"] = "True"
environ["GONE"] = "123"
del environ["GONE"]
# We can read the environ.
assert environ["TESTING"] == "True"
assert "GONE" not in environ
# We cannot mutate these keys now that we've read them.
with pytest.raises(EnvironError):
environ["TESTING"] = "False"
with pytest.raises(EnvironError):
del environ["GONE"]
# Test coverage of abstract methods for MutableMapping.
environ = Environ()
assert list(iter(environ)) == list(iter(os.environ))
assert len(environ) == len(os.environ)
def test_config_with_env_prefix(tmpdir: Path, monkeypatch: pytest.MonkeyPatch) -> None:
config = Config(environ={"APP_DEBUG": "value", "ENVIRONMENT": "dev"}, env_prefix="APP_")
assert config.get("DEBUG") == "value"
with pytest.raises(KeyError):
config.get("ENVIRONMENT")
def test_config_with_encoding(tmpdir: Path) -> None:
path = tmpdir / ".env"
path.write_text("MESSAGE=Hello 世界\n", encoding="utf-8")
config = Config(path, encoding="utf-8")
assert config.get("MESSAGE") == "Hello 世界"
================================================
FILE: tests/test_convertors.py
================================================
from collections.abc import Iterator
from datetime import datetime
from uuid import UUID
import pytest
from starlette import convertors
from starlette.convertors import Convertor, register_url_convertor
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route, Router
from tests.types import TestClientFactory
@pytest.fixture(scope="module", autouse=True)
def refresh_convertor_types() -> Iterator[None]:
convert_types = convertors.CONVERTOR_TYPES.copy()
yield
convertors.CONVERTOR_TYPES = convert_types
class DateTimeConvertor(Convertor[datetime]):
regex = "[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}([.][0-9]+)?"
def convert(self, value: str) -> datetime:
return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S")
def to_string(self, value: datetime) -> str:
return value.strftime("%Y-%m-%dT%H:%M:%S")
@pytest.fixture(scope="function")
def app() -> Router:
register_url_convertor("datetime", DateTimeConvertor())
def datetime_convertor(request: Request) -> JSONResponse:
param = request.path_params["param"]
assert isinstance(param, datetime)
return JSONResponse({"datetime": param.strftime("%Y-%m-%dT%H:%M:%S")})
return Router(
routes=[
Route(
"/datetime/{param:datetime}",
endpoint=datetime_convertor,
name="datetime-convertor",
)
]
)
def test_datetime_convertor(test_client_factory: TestClientFactory, app: Router) -> None:
client = test_client_factory(app)
response = client.get("/datetime/2020-01-01T00:00:00")
assert response.json() == {"datetime": "2020-01-01T00:00:00"}
assert (
app.url_path_for("datetime-convertor", param=datetime(1996, 1, 22, 23, 0, 0)) == "/datetime/1996-01-22T23:00:00"
)
@pytest.mark.parametrize("param, status_code", [("1.0", 200), ("1-0", 404)])
def test_default_float_convertor(test_client_factory: TestClientFactory, param: str, status_code: int) -> None:
def float_convertor(request: Request) -> JSONResponse:
param = request.path_params["param"]
assert isinstance(param, float)
return JSONResponse({"float": param})
app = Router(routes=[Route("/{param:float}", endpoint=float_convertor)])
client = test_client_factory(app)
response = client.get(f"/{param}")
assert response.status_code == status_code
@pytest.mark.parametrize(
"param, status_code",
[
("00000000-aaaa-ffff-9999-000000000000", 200),
("00000000aaaaffff9999000000000000", 200),
("00000000-AAAA-FFFF-9999-000000000000", 200),
("00000000AAAAFFFF9999000000000000", 200),
("not-a-uuid", 404),
],
)
def test_default_uuid_convertor(test_client_factory: TestClientFactory, param: str, status_code: int) -> None:
def uuid_convertor(request: Request) -> JSONResponse:
param = request.path_params["param"]
assert isinstance(param, UUID)
return JSONResponse("ok")
app = Router(routes=[Route("/{param:uuid}", endpoint=uuid_convertor)])
client = test_client_factory(app)
response = client.get(f"/{param}")
assert response.status_code == status_code
================================================
FILE: tests/test_datastructures.py
================================================
import io
from tempfile import SpooledTemporaryFile
from typing import BinaryIO
import pytest
from starlette.datastructures import (
URL,
CommaSeparatedStrings,
FormData,
Headers,
MultiDict,
MutableHeaders,
QueryParams,
UploadFile,
)
def test_url() -> None:
u = URL("https://example.org:123/path/to/somewhere?abc=123#anchor")
assert u.scheme == "https"
assert u.hostname == "example.org"
assert u.port == 123
assert u.netloc == "example.org:123"
assert u.username is None
assert u.password is None
assert u.path == "/path/to/somewhere"
assert u.query == "abc=123"
assert u.fragment == "anchor"
new = u.replace(scheme="http")
assert new == "http://example.org:123/path/to/somewhere?abc=123#anchor"
assert new.scheme == "http"
new = u.replace(port=None)
assert new == "https://example.org/path/to/somewhere?abc=123#anchor"
assert new.port is None
new = u.replace(hostname="example.com")
assert new == "https://example.com:123/path/to/somewhere?abc=123#anchor"
assert new.hostname == "example.com"
ipv6_url = URL("https://[fe::2]:12345")
new = ipv6_url.replace(port=8080)
assert new == "https://[fe::2]:8080"
new = ipv6_url.replace(username="username", password="password")
assert new == "https://username:password@[fe::2]:12345"
assert new.netloc == "username:password@[fe::2]:12345"
ipv6_url = URL("https://[fe::2]")
new = ipv6_url.replace(port=123)
assert new == "https://[fe::2]:123"
url = URL("http://u:p@host/")
assert url.replace(hostname="bar") == URL("http://u:p@bar/")
url = URL("http://u:p@host:80")
assert url.replace(port=88) == URL("http://u:p@host:88")
url = URL("http://host:80")
assert url.replace(username="u") == URL("http://u@host:80")
def test_url_query_params() -> None:
u = URL("https://example.org/path/?page=3")
assert u.query == "page=3"
u = u.include_query_params(page=4)
assert str(u) == "https://example.org/path/?page=4"
u = u.include_query_params(search="testing")
assert str(u) == "https://example.org/path/?page=4&search=testing"
u = u.replace_query_params(order="name")
assert str(u) == "https://example.org/path/?order=name"
u = u.remove_query_params("order")
assert str(u) == "https://example.org/path/"
u = u.include_query_params(page=4, search="testing")
assert str(u) == "https://example.org/path/?page=4&search=testing"
u = u.remove_query_params(["page", "search"])
assert str(u) == "https://example.org/path/"
def test_hidden_password() -> None:
u = URL("https://example.org/path/to/somewhere")
assert repr(u) == "URL('https://example.org/path/to/somewhere')"
u = URL("https://username@example.org/path/to/somewhere")
assert repr(u) == "URL('https://username@example.org/path/to/somewhere')"
u = URL("https://username:password@example.org/path/to/somewhere")
assert repr(u) == "URL('https://username:********@example.org/path/to/somewhere')"
def test_csv() -> None:
csv = CommaSeparatedStrings('"localhost", "127.0.0.1", 0.0.0.0')
assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"]
assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])"
assert str(csv) == "'localhost', '127.0.0.1', '0.0.0.0'"
assert csv[0] == "localhost"
assert len(csv) == 3
csv = CommaSeparatedStrings("'localhost', '127.0.0.1', 0.0.0.0")
assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"]
assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])"
assert str(csv) == "'localhost', '127.0.0.1', '0.0.0.0'"
csv = CommaSeparatedStrings("localhost, 127.0.0.1, 0.0.0.0")
assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"]
assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])"
assert str(csv) == "'localhost', '127.0.0.1', '0.0.0.0'"
csv = CommaSeparatedStrings(["localhost", "127.0.0.1", "0.0.0.0"])
assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"]
assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])"
assert str(csv) == "'localhost', '127.0.0.1', '0.0.0.0'"
def test_url_from_scope() -> None:
u = URL(scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []})
assert u == "/path/to/somewhere?abc=123"
assert repr(u) == "URL('/path/to/somewhere?abc=123')"
u = URL(
scope={
"scheme": "https",
"server": ("example.org", 123),
"path": "/path/to/somewhere",
"query_string": b"abc=123",
"headers": [],
}
)
assert u == "https://example.org:123/path/to/somewhere?abc=123"
assert repr(u) == "URL('https://example.org:123/path/to/somewhere?abc=123')"
u = URL(
scope={
"scheme": "https",
"server": ("example.org", 443),
"path": "/path/to/somewhere",
"query_string": b"abc=123",
"headers": [],
}
)
assert u == "https://example.org/path/to/somewhere?abc=123"
assert repr(u) == "URL('https://example.org/path/to/somewhere?abc=123')"
u = URL(
scope={
"scheme": "http",
"path": "/some/path",
"query_string": b"query=string",
"headers": [
(b"content-type", b"text/html"),
(b"host", b"example.com:8000"),
(b"accept", b"text/html"),
],
}
)
assert u == "http://example.com:8000/some/path?query=string"
assert repr(u) == "URL('http://example.com:8000/some/path?query=string')"
def test_headers() -> None:
h = Headers(raw=[(b"a", b"123"), (b"a", b"456"), (b"b", b"789")])
assert "a" in h
assert "A" in h
assert "b" in h
assert "B" in h
assert "c" not in h
assert h["a"] == "123"
assert h.get("a") == "123"
assert h.get("nope", default=None) is None
assert h.getlist("a") == ["123", "456"]
assert h.keys() == ["a", "a", "b"]
assert h.values() == ["123", "456", "789"]
assert h.items() == [("a", "123"), ("a", "456"), ("b", "789")]
assert list(h) == ["a", "a", "b"]
assert dict(h) == {"a": "123", "b": "789"}
assert repr(h) == "Headers(raw=[(b'a', b'123'), (b'a', b'456'), (b'b', b'789')])"
assert h == Headers(raw=[(b"a", b"123"), (b"b", b"789"), (b"a", b"456")])
assert h != [(b"a", b"123"), (b"A", b"456"), (b"b", b"789")]
h = Headers({"a": "123", "b": "789"})
assert h["A"] == "123"
assert h["B"] == "789"
assert h.raw == [(b"a", b"123"), (b"b", b"789")]
assert repr(h) == "Headers({'a': '123', 'b': '789'})"
def test_mutable_headers() -> None:
h = MutableHeaders()
assert dict(h) == {}
h["a"] = "1"
assert dict(h) == {"a": "1"}
h["a"] = "2"
assert dict(h) == {"a": "2"}
h.setdefault("a", "3")
assert dict(h) == {"a": "2"}
h.setdefault("b", "4")
assert dict(h) == {"a": "2", "b": "4"}
del h["a"]
assert dict(h) == {"b": "4"}
assert h.raw == [(b"b", b"4")]
def test_mutable_headers_merge() -> None:
h = MutableHeaders()
h = h | MutableHeaders({"a": "1"})
assert isinstance(h, MutableHeaders)
assert dict(h) == {"a": "1"}
assert h.items() == [("a", "1")]
assert h.raw == [(b"a", b"1")]
def test_mutable_headers_merge_dict() -> None:
h = MutableHeaders()
h = h | {"a": "1"}
assert isinstance(h, MutableHeaders)
assert dict(h) == {"a": "1"}
assert h.items() == [("a", "1")]
assert h.raw == [(b"a", b"1")]
def test_mutable_headers_update() -> None:
h = MutableHeaders()
h |= MutableHeaders({"a": "1"})
assert isinstance(h, MutableHeaders)
assert dict(h) == {"a": "1"}
assert h.items() == [("a", "1")]
assert h.raw == [(b"a", b"1")]
def test_mutable_headers_update_dict() -> None:
h = MutableHeaders()
h |= {"a": "1"}
assert isinstance(h, MutableHeaders)
assert dict(h) == {"a": "1"}
assert h.items() == [("a", "1")]
assert h.raw == [(b"a", b"1")]
def test_mutable_headers_merge_not_mapping() -> None:
h = MutableHeaders()
with pytest.raises(TypeError):
h |= {"not_mapping"} # type: ignore[arg-type]
with pytest.raises(TypeError):
h | {"not_mapping"} # type: ignore[operator]
def test_headers_mutablecopy() -> None:
h = Headers(raw=[(b"a", b"123"), (b"a", b"456"), (b"b", b"789")])
c = h.mutablecopy()
assert c.items() == [("a", "123"), ("a", "456"), ("b", "789")]
c["a"] = "abc"
assert c.items() == [("a", "abc"), ("b", "789")]
def test_mutable_headers_from_scope() -> None:
# "headers" in scope is not necessarily a list; any iterable of pairs, such as a tuple, is accepted.
h = MutableHeaders(scope={"headers": ((b"a", b"1"),)})
assert dict(h) == {"a": "1"}
h.update({"b": "2"})
assert dict(h) == {"a": "1", "b": "2"}
assert list(h.items()) == [("a", "1"), ("b", "2")]
assert list(h.raw) == [(b"a", b"1"), (b"b", b"2")]
def test_url_blank_params() -> None:
q = QueryParams("a=123&abc&def&b=456")
assert "a" in q
assert "abc" in q
assert "def" in q
assert "b" in q
val = q.get("abc")
assert val is not None
assert len(val) == 0
assert len(q["a"]) == 3
assert list(q.keys()) == ["a", "abc", "def", "b"]
def test_queryparams() -> None:
q = QueryParams("a=123&a=456&b=789")
assert "a" in q
assert "A" not in q
assert "c" not in q
assert q["a"] == "456"
assert q.get("a") == "456"
assert q.get("nope", default=None) is None
assert q.getlist("a") == ["123", "456"]
assert list(q.keys()) == ["a", "b"]
assert list(q.values()) == ["456", "789"]
assert list(q.items()) == [("a", "456"), ("b", "789")]
assert len(q) == 2
assert list(q) == ["a", "b"]
assert dict(q) == {"a": "456", "b": "789"}
assert str(q) == "a=123&a=456&b=789"
assert repr(q) == "QueryParams('a=123&a=456&b=789')"
assert QueryParams({"a": "123", "b": "456"}) == QueryParams([("a", "123"), ("b", "456")])
assert QueryParams({"a": "123", "b": "456"}) == QueryParams("a=123&b=456")
assert QueryParams({"a": "123", "b": "456"}) == QueryParams({"b": "456", "a": "123"})
assert QueryParams() == QueryParams({})
assert QueryParams([("a", "123"), ("a", "456")]) == QueryParams("a=123&a=456")
assert QueryParams({"a": "123", "b": "456"}) != "invalid"
q = QueryParams([("a", "123"), ("a", "456")])
assert QueryParams(q) == q
@pytest.mark.anyio
async def test_upload_file_file_input() -> None:
"""Test passing file/stream into the UploadFile constructor"""
stream = io.BytesIO(b"data")
file = UploadFile(filename="file", file=stream, size=4)
assert await file.read() == b"data"
assert file.size == 4
await file.write(b" and more data!")
assert await file.read() == b""
assert file.size == 19
await file.seek(0)
assert await file.read() == b"data and more data!"
@pytest.mark.anyio
async def test_upload_file_without_size() -> None:
"""Test passing file/stream into the UploadFile constructor without size"""
stream = io.BytesIO(b"data")
file = UploadFile(filename="file", file=stream)
assert await file.read() == b"data"
assert file.size is None
await file.write(b" and more data!")
assert await file.read() == b""
assert file.size is None
await file.seek(0)
assert await file.read() == b"data and more data!"
@pytest.mark.anyio
@pytest.mark.parametrize("max_size", [1, 1024], ids=["rolled", "unrolled"])
async def test_uploadfile_rolling(max_size: int) -> None:
"""Test that we can read from and write to a SpooledTemporaryFile
managed by UploadFile, both before and after it rolls over to disk.
"""
stream: BinaryIO = SpooledTemporaryFile( # type: ignore[assignment]
max_size=max_size
)
file = UploadFile(filename="file", file=stream, size=0)
assert await file.read() == b""
assert file.size == 0
await file.write(b"data")
assert await file.read() == b""
assert file.size == 4
await file.seek(0)
assert await file.read() == b"data"
await file.write(b" more")
assert await file.read() == b""
assert file.size == 9
await file.seek(0)
assert await file.read() == b"data more"
assert file.size == 9
await file.close()
def test_formdata() -> None:
stream = io.BytesIO(b"data")
upload = UploadFile(filename="file", file=stream, size=4)
form = FormData([("a", "123"), ("a", "456"), ("b", upload)])
assert "a" in form
assert "A" not in form
assert "c" not in form
assert form["a"] == "456"
assert form.get("a") == "456"
assert form.get("nope", default=None) is None
assert form.getlist("a") == ["123", "456"]
assert list(form.keys()) == ["a", "b"]
assert list(form.values()) == ["456", upload]
assert list(form.items()) == [("a", "456"), ("b", upload)]
assert len(form) == 2
assert list(form) == ["a", "b"]
assert dict(form) == {"a": "456", "b": upload}
assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])"
assert FormData(form) == form
assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")])
assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"}
@pytest.mark.anyio
async def test_upload_file_repr() -> None:
stream = io.BytesIO(b"data")
file = UploadFile(filename="file", file=stream, size=4)
assert repr(file) == "UploadFile(filename='file', size=4, headers=Headers({}))"
@pytest.mark.anyio
async def test_upload_file_repr_headers() -> None:
stream = io.BytesIO(b"data")
file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"}))
assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))"
def test_multidict() -> None:
q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")])
assert "a" in q
assert "A" not in q
assert "c" not in q
assert q["a"] == "456"
assert q.get("a") == "456"
assert q.get("nope", default=None) is None
assert q.getlist("a") == ["123", "456"]
assert list(q.keys()) == ["a", "b"]
assert list(q.values()) == ["456", "789"]
assert list(q.items()) == [("a", "456"), ("b", "789")]
assert len(q) == 2
assert list(q) == ["a", "b"]
assert dict(q) == {"a": "456", "b": "789"}
assert str(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])"
assert repr(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])"
assert MultiDict({"a": "123", "b": "456"}) == MultiDict([("a", "123"), ("b", "456")])
assert MultiDict({"a": "123", "b": "456"}) == MultiDict({"b": "456", "a": "123"})
assert MultiDict() == MultiDict({})
assert MultiDict({"a": "123", "b": "456"}) != "invalid"
q = MultiDict([("a", "123"), ("a", "456")])
assert MultiDict(q) == q
q = MultiDict([("a", "123"), ("a", "456")])
q["a"] = "789"
assert q["a"] == "789"
assert q.get("a") == "789"
assert q.getlist("a") == ["789"]
q = MultiDict([("a", "123"), ("a", "456")])
del q["a"]
assert q.get("a") is None
assert repr(q) == "MultiDict([])"
q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")])
assert q.pop("a") == "456"
assert q.get("a", None) is None
assert repr(q) == "MultiDict([('b', '789')])"
q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")])
item = q.popitem()
assert q.get(item[0]) is None
q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")])
assert q.poplist("a") == ["123", "456"]
assert q.get("a") is None
assert repr(q) == "MultiDict([('b', '789')])"
q = MultiDict([("a", "123"), ("a", "456"), ("b", "789")])
q.clear()
assert q.get("a") is None
assert repr(q) == "MultiDict([])"
q = MultiDict([("a", "123")])
q.setlist("a", ["456", "789"])
assert q.getlist("a") == ["456", "789"]
q.setlist("b", [])
assert "b" not in q
q = MultiDict([("a", "123")])
assert q.setdefault("a", "456") == "123"
assert q.getlist("a") == ["123"]
assert q.setdefault("b", "456") == "456"
assert q.getlist("b") == ["456"]
assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])"
q = MultiDict([("a", "123")])
q.append("a", "456")
assert q.getlist("a") == ["123", "456"]
assert repr(q) == "MultiDict([('a', '123'), ('a', '456')])"
q = MultiDict([("a", "123"), ("b", "456")])
q.update({"a": "789"})
assert q.getlist("a") == ["789"]
assert q == MultiDict([("a", "789"), ("b", "456")])
q = MultiDict([("a", "123"), ("b", "456")])
q.update(q)
assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])"
q = MultiDict([("a", "123"), ("a", "456")])
q.update([("a", "123")])
assert q.getlist("a") == ["123"]
q.update([("a", "456")], a="789", b="123")
assert q == MultiDict([("a", "456"), ("a", "789"), ("b", "123")])
================================================
FILE: tests/test_endpoints.py
================================================
from collections.abc import Iterator
import pytest
from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route, Router
from starlette.testclient import TestClient
from starlette.websockets import WebSocket
from tests.types import TestClientFactory
class Homepage(HTTPEndpoint):
async def get(self, request: Request) -> PlainTextResponse:
username = request.path_params.get("username")
if username is None:
return PlainTextResponse("Hello, world!")
return PlainTextResponse(f"Hello, {username}!")
app = Router(routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)])
@pytest.fixture
def client(test_client_factory: TestClientFactory) -> Iterator[TestClient]:
with test_client_factory(app) as client:
yield client
def test_http_endpoint_route(client: TestClient) -> None:
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, world!"
def test_http_endpoint_route_path_params(client: TestClient) -> None:
response = client.get("/tomchristie")
assert response.status_code == 200
assert response.text == "Hello, tomchristie!"
def test_http_endpoint_route_method(client: TestClient) -> None:
response = client.post("/")
assert response.status_code == 405
assert response.text == "Method Not Allowed"
assert response.headers["allow"] == "GET"
def test_websocket_endpoint_on_connect(test_client_factory: TestClientFactory) -> None:
class WebSocketApp(WebSocketEndpoint):
async def on_connect(self, websocket: WebSocket) -> None:
assert websocket["subprotocols"] == ["soap", "wamp"]
await websocket.accept(subprotocol="wamp")
client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws", subprotocols=["soap", "wamp"]) as websocket:
assert websocket.accepted_subprotocol == "wamp"
def test_websocket_endpoint_on_receive_bytes(
test_client_factory: TestClientFactory,
) -> None:
class WebSocketApp(WebSocketEndpoint):
encoding = "bytes"
async def on_receive(self, websocket: WebSocket, data: bytes) -> None:
await websocket.send_bytes(b"Message bytes was: " + data)
client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.send_bytes(b"Hello, world!")
_bytes = websocket.receive_bytes()
assert _bytes == b"Message bytes was: Hello, world!"
with pytest.raises(RuntimeError):
with client.websocket_connect("/ws") as websocket:
websocket.send_text("Hello world")
def test_websocket_endpoint_on_receive_json(
test_client_factory: TestClientFactory,
) -> None:
class WebSocketApp(WebSocketEndpoint):
encoding = "json"
async def on_receive(self, websocket: WebSocket, data: str) -> None:
await websocket.send_json({"message": data})
client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.send_json({"hello": "world"})
data = websocket.receive_json()
assert data == {"message": {"hello": "world"}}
with pytest.raises(RuntimeError):
with client.websocket_connect("/ws") as websocket:
websocket.send_text("Hello world")
def test_websocket_endpoint_on_receive_json_binary(
test_client_factory: TestClientFactory,
) -> None:
class WebSocketApp(WebSocketEndpoint):
encoding = "json"
async def on_receive(self, websocket: WebSocket, data: str) -> None:
await websocket.send_json({"message": data}, mode="binary")
client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.send_json({"hello": "world"}, mode="binary")
data = websocket.receive_json(mode="binary")
assert data == {"message": {"hello": "world"}}
def test_websocket_endpoint_on_receive_text(
test_client_factory: TestClientFactory,
) -> None:
class WebSocketApp(WebSocketEndpoint):
encoding = "text"
async def on_receive(self, websocket: WebSocket, data: str) -> None:
await websocket.send_text(f"Message text was: {data}")
client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.send_text("Hello, world!")
_text = websocket.receive_text()
assert _text == "Message text was: Hello, world!"
with pytest.raises(RuntimeError):
with client.websocket_connect("/ws") as websocket:
websocket.send_bytes(b"Hello world")
def test_websocket_endpoint_on_default(test_client_factory: TestClientFactory) -> None:
class WebSocketApp(WebSocketEndpoint):
encoding = None
async def on_receive(self, websocket: WebSocket, data: str) -> None:
await websocket.send_text(f"Message text was: {data}")
client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.send_text("Hello, world!")
_text = websocket.receive_text()
assert _text == "Message text was: Hello, world!"
def test_websocket_endpoint_on_disconnect(
test_client_factory: TestClientFactory,
) -> None:
class WebSocketApp(WebSocketEndpoint):
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
assert close_code == 1001
await websocket.close(code=close_code)
client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.close(code=1001)
================================================
FILE: tests/test_exceptions.py
================================================
from collections.abc import Generator
from typing import Any
import pytest
from pytest import MonkeyPatch
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import Receive, Scope, Send
from tests.types import TestClientFactory
def raise_runtime_error(request: Request) -> None:
raise RuntimeError("Yikes")
def not_acceptable(request: Request) -> None:
raise HTTPException(status_code=406)
def no_content(request: Request) -> None:
raise HTTPException(status_code=204)
def not_modified(request: Request) -> None:
raise HTTPException(status_code=304)
def with_headers(request: Request) -> None:
raise HTTPException(status_code=200, headers={"x-potato": "always"})
class BadBodyException(HTTPException):
pass
async def read_body_and_raise_exc(request: Request) -> None:
await request.body()
raise BadBodyException(422)
async def handler_that_reads_body(request: Request, exc: BadBodyException) -> JSONResponse:
body = await request.body()
return JSONResponse(status_code=422, content={"body": body.decode()})
class HandledExcAfterResponse:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
response = PlainTextResponse("OK", status_code=200)
await response(scope, receive, send)
raise HTTPException(status_code=406)
router = Router(
routes=[
Route("/runtime_error", endpoint=raise_runtime_error),
Route("/not_acceptable", endpoint=not_acceptable),
Route("/no_content", endpoint=no_content),
Route("/not_modified", endpoint=not_modified),
Route("/with_headers", endpoint=with_headers),
Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]),
]
)
app = ExceptionMiddleware(
router,
handlers={BadBodyException: handler_that_reads_body}, # type: ignore[dict-item]
)
@pytest.fixture
def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
with test_client_factory(app) as client:
yield client
def test_not_acceptable(client: TestClient) -> None:
response = client.get("/not_acceptable")
assert response.status_code == 406
assert response.text == "Not Acceptable"
def test_no_content(client: TestClient) -> None:
response = client.get("/no_content")
assert response.status_code == 204
assert "content-length" not in response.headers
def test_not_modified(client: TestClient) -> None:
response = client.get("/not_modified")
assert response.status_code == 304
assert response.text == ""
def test_with_headers(client: TestClient) -> None:
response = client.get("/with_headers")
assert response.status_code == 200
assert response.headers["x-potato"] == "always"
def test_websockets_should_raise(client: TestClient) -> None:
with pytest.raises(RuntimeError):
with client.websocket_connect("/runtime_error"):
pass # pragma: no cover
def test_handled_exc_after_response(test_client_factory: TestClientFactory, client: TestClient) -> None:
# A 406 HTTPException is raised *after* the response has already been sent.
# The exception middleware should raise a RuntimeError.
with pytest.raises(RuntimeError, match="Caught handled exception, but response already started."):
client.get("/handled_exc_after_response")
# With `raise_server_exceptions=False`, the test client does not re-raise the
# server error, so we can inspect the response exactly as a real client would see it.
allow_200_client = test_client_factory(app, raise_server_exceptions=False)
response = allow_200_client.get("/handled_exc_after_response")
assert response.status_code == 200
assert response.text == "OK"
def test_force_500_response(test_client_factory: TestClientFactory) -> None:
# Use a sentinel variable to confirm that the app is actually called,
# so the 500 comes from the RuntimeError raised inside it and not from
# something else, such as an incorrect ASGI app signature.
called = False
async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal called
called = True
raise RuntimeError()
force_500_client = test_client_factory(app, raise_server_exceptions=False)
response = force_500_client.get("/")
assert called
assert response.status_code == 500
assert response.text == ""
def test_http_str() -> None:
assert str(HTTPException(status_code=404)) == "404: Not Found"
assert str(HTTPException(404, "Not Found: foo")) == "404: Not Found: foo"
assert str(HTTPException(404, headers={"key": "value"})) == "404: Not Found"
def test_http_repr() -> None:
assert repr(HTTPException(404)) == ("HTTPException(status_code=404, detail='Not Found')")
assert repr(HTTPException(404, detail="Not Found: foo")) == (
"HTTPException(status_code=404, detail='Not Found: foo')"
)
class CustomHTTPException(HTTPException):
pass
assert repr(CustomHTTPException(500, detail="Something custom")) == (
"CustomHTTPException(status_code=500, detail='Something custom')"
)
def test_websocket_str() -> None:
assert str(WebSocketException(1008)) == "1008: "
assert str(WebSocketException(1008, "Policy Violation")) == "1008: Policy Violation"
def test_websocket_repr() -> None:
assert repr(WebSocketException(1008, reason="Policy Violation")) == (
"WebSocketException(code=1008, reason='Policy Violation')"
)
class CustomWebSocketException(WebSocketException):
pass
assert (
repr(CustomWebSocketException(1013, reason="Something custom"))
== "CustomWebSocketException(code=1013, reason='Something custom')"
)
def test_request_in_app_and_handler_is_the_same_object(client: TestClient) -> None:
response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!")
assert response.status_code == 422
assert response.json() == {"body": "Hello!"}
def test_http_exception_does_not_use_threadpool(client: TestClient, monkeypatch: MonkeyPatch) -> None:
"""
Verify that handling HTTPException does not invoke run_in_threadpool,
confirming the handler correctly runs in the main async context.
"""
from starlette import _exception_handler
# Replace run_in_threadpool with a function that raises an error
def mock_run_in_threadpool(*args: Any, **kwargs: Any) -> None:
pytest.fail("run_in_threadpool should not be called for HTTP exceptions") # pragma: no cover
# Apply the monkeypatch only during this test
monkeypatch.setattr(_exception_handler, "run_in_threadpool", mock_run_in_threadpool)
# This should succeed because http_exception is async and won't use run_in_threadpool
response = client.get("/not_acceptable")
assert response.status_code == 406
def test_handlers_annotations() -> None:
"""Check that async exception handlers are accepted by type checkers.
We annotate the handlers' exceptions with plain `Exception` to avoid variance issues
when using other exception types.
"""
async def async_catch_all_handler(request: Request, exc: Exception) -> JSONResponse:
raise NotImplementedError
def sync_catch_all_handler(request: Request, exc: Exception) -> JSONResponse:
raise NotImplementedError
ExceptionMiddleware(router, handlers={Exception: sync_catch_all_handler})
ExceptionMiddleware(router, handlers={Exception: async_catch_all_handler})
================================================
FILE: tests/test_formparsers.py
================================================
from __future__ import annotations
import os
import threading
from collections.abc import Generator
from contextlib import AbstractContextManager, nullcontext as does_not_raise
from io import BytesIO
from pathlib import Path
from tempfile import SpooledTemporaryFile
from typing import Any, ClassVar
from unittest import mock
import pytest
from starlette.applications import Starlette
from starlette.datastructures import UploadFile
from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount
from starlette.types import ASGIApp, Receive, Scope, Send
from tests.types import TestClientFactory
class ForceMultipartDict(dict[Any, Any]):
def __bool__(self) -> bool:
return True
# FORCE_MULTIPART is an empty dict that boolean-evaluates as `True`.
FORCE_MULTIPART = ForceMultipartDict()
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.form()
output: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
output[key] = {
"filename": value.filename,
"size": value.size,
"content": content.decode(),
"content_type": value.content_type,
}
else:
output[key] = value
await request.close()
response = JSONResponse(output)
await response(scope, receive, send)
async def multi_items_app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.form()
output: dict[str, list[Any]] = {}
for key, value in data.multi_items():
if key not in output:
output[key] = []
if isinstance(value, UploadFile):
content = await value.read()
output[key].append(
{
"filename": value.filename,
"size": value.size,
"content": content.decode(),
"content_type": value.content_type,
}
)
else:
output[key].append(value)
await request.close()
response = JSONResponse(output)
await response(scope, receive, send)
async def app_with_headers(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.form()
output: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
output[key] = {
"filename": value.filename,
"size": value.size,
"content": content.decode(),
"content_type": value.content_type,
"headers": list(value.headers.items()),
}
else:
output[key] = value
await request.close()
response = JSONResponse(output)
await response(scope, receive, send)
async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
# Read bytes, to force request.stream() to return the already parsed body
await request.body()
data = await request.form()
output = {}
for key, value in data.items():
output[key] = value
await request.close()
response = JSONResponse(output)
await response(scope, receive, send)
async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None:
"""Helper app to monitor what thread the app was called on.
This can later be used to validate thread/event loop operations.
"""
request = Request(scope, receive)
# Make sure we parse the form
await request.form()
await request.close()
# Send back the current thread id
response = JSONResponse({"thread_ident": threading.current_thread().ident})
await response(scope, receive, send)
def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size)
output: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
output[key] = {
"filename": value.filename,
"size": value.size,
"content": content.decode(),
"content_type": value.content_type,
}
else:
output[key] = value
await request.close()
response = JSONResponse(output)
await response(scope, receive, send)
return app
def test_multipart_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART)
assert response.json() == {"some": "data"}
def test_multipart_request_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    path = os.path.join(tmpdir, "test.txt")
    # 14 bytes of content, matching the "size" asserted below
    with open(path, "wb") as file:
        file.write(b"<file content>")
    client = test_client_factory(app)
    with open(path, "rb") as f:
        response = client.post("/", files={"test": f})
        assert response.json() == {
            "test": {
                "filename": "test.txt",
                "size": 14,
                "content": "<file content>",
                "content_type": "text/plain",
            }
        }
def test_multipart_request_files_with_content_type(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    path = os.path.join(tmpdir, "test.txt")
    # 14 bytes of content, matching the "size" asserted below
    with open(path, "wb") as file:
        file.write(b"<file content>")
    client = test_client_factory(app)
    with open(path, "rb") as f:
        response = client.post("/", files={"test": ("test.txt", f, "text/plain")})
        assert response.json() == {
            "test": {
                "filename": "test.txt",
                "size": 14,
                "content": "<file content>",
                "content_type": "text/plain",
            }
        }
def test_multipart_request_multiple_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    # Each file holds 15 bytes, matching the "size" asserted below
    path1 = os.path.join(tmpdir, "test1.txt")
    with open(path1, "wb") as file:
        file.write(b"<file1 content>")
    path2 = os.path.join(tmpdir, "test2.txt")
    with open(path2, "wb") as file:
        file.write(b"<file2 content>")
    client = test_client_factory(app)
    with open(path1, "rb") as f1, open(path2, "rb") as f2:
        response = client.post("/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")})
        assert response.json() == {
            "test1": {
                "filename": "test1.txt",
                "size": 15,
                "content": "<file1 content>",
                "content_type": "text/plain",
            },
            "test2": {
                "filename": "test2.txt",
                "size": 15,
                "content": "<file2 content>",
                "content_type": "text/plain",
            },
        }
def test_multipart_request_multiple_files_with_headers(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    path1 = os.path.join(tmpdir, "test1.txt")
    with open(path1, "wb") as file:
        file.write(b"<file1 content>")
    path2 = os.path.join(tmpdir, "test2.txt")
    with open(path2, "wb") as file:
        file.write(b"<file2 content>")
    client = test_client_factory(app_with_headers)
    with open(path1, "rb") as f1, open(path2, "rb") as f2:
        response = client.post(
            "/",
            files=[
                # No filename, so this part is parsed as a plain form field
                ("test1", (None, f1)),
                ("test2", ("test2.txt", f2, "text/plain", {"x-custom": "f2"})),
            ],
        )
        assert response.json() == {
            "test1": "<file1 content>",
            "test2": {
                "filename": "test2.txt",
                "size": 15,
                "content": "<file2 content>",
                "content_type": "text/plain",
                "headers": [
                    [
                        "content-disposition",
                        'form-data; name="test2"; filename="test2.txt"',
                    ],
                    ["x-custom", "f2"],
                    ["content-type", "text/plain"],
                ],
            },
        }
def test_multi_items(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    path1 = os.path.join(tmpdir, "test1.txt")
    with open(path1, "wb") as file:
        file.write(b"<file1 content>")
    path2 = os.path.join(tmpdir, "test2.txt")
    with open(path2, "wb") as file:
        file.write(b"<file2 content>")
    client = test_client_factory(multi_items_app)
    with open(path1, "rb") as f1, open(path2, "rb") as f2:
        # The same key "test1" carries one plain field and two files
        response = client.post(
            "/",
            data={"test1": "abc"},
            files=[("test1", f1), ("test1", ("test2.txt", f2, "text/plain"))],
        )
        assert response.json() == {
            "test1": [
                "abc",
                {
                    "filename": "test1.txt",
                    "size": 15,
                    "content": "<file1 content>",
                    "content_type": "text/plain",
                },
                {
                    "filename": "test2.txt",
                    "size": 15,
                    "content": "<file2 content>",
                    "content_type": "text/plain",
                },
            ]
        }
def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post(
        "/",
        data=(
            # data
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"  # type: ignore
            b'Content-Disposition: form-data; name="field0"\r\n\r\n'
            b"value0\r\n"
            # file
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
            b'Content-Disposition: form-data; name="file"; filename="file.txt"\r\n'
            b"Content-Type: text/plain\r\n\r\n"
            b"<file content>\r\n"
            # data
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"
            b'Content-Disposition: form-data; name="field1"\r\n\r\n'
            b"value1\r\n"
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
        ),
        headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
    )
    assert response.json() == {
        "file": {
            "filename": "file.txt",
            "size": 14,
            "content": "<file content>",
            "content_type": "text/plain",
        },
        "field0": "value0",
        "field1": "value1",
    }
class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]):
"""Helper class to track which threads performed the rollover operation.
This is not threadsafe/multi-test safe.
"""
rollover_threads: ClassVar[set[int | None]] = set()
def rollover(self) -> None:
ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident)
super().rollover()
@pytest.fixture
def mock_spooled_temporary_file() -> Generator[None]:
try:
with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile):
yield
finally:
ThreadTrackingSpooledTemporaryFile.rollover_threads.clear()
def test_multipart_request_large_file_rollover_in_background_thread(
mock_spooled_temporary_file: None, test_client_factory: TestClientFactory
) -> None:
"""Test that Spooled file rollovers happen in background threads."""
data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1))
client = test_client_factory(app_monitor_thread)
response = client.post("/", files=[("test_large", data)])
assert response.status_code == 200
# Parse the event thread id from the API response and ensure we have one
app_thread_ident = response.json().get("thread_ident")
assert app_thread_ident is not None
# Ensure the app thread was not the same as the rollover one and that a rollover thread exists
assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads
assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1
def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post(
        "/",
        data=(
            # file
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"  # type: ignore
            b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'
            b"Content-Type: text/plain\r\n\r\n"
            b"<file content>\r\n"
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
        ),
        headers={"Content-Type": ("multipart/form-data; charset=utf-8; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
    )
    assert response.json() == {
        "file": {
            "filename": "文書.txt",
            "size": 14,
            "content": "<file content>",
            "content_type": "text/plain",
        }
    }
def test_multipart_request_without_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(app)
    response = client.post(
        "/",
        data=(
            # file
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"  # type: ignore
            b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n'
            b"Content-Type: image/jpeg\r\n\r\n"
            b"<file content>\r\n"
            b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
        ),
        headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
    )
    assert response.json() == {
        "file": {
            "filename": "画像.jpg",
            "size": 14,
            "content": "<file content>",
            "content_type": "image/jpeg",
        }
    }
def test_multipart_request_with_encoded_value(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post(
"/",
data=(
b"--20b303e711c4ab8c443184ac833ab00f\r\n" # type: ignore
b"Content-Disposition: form-data; "
b'name="value"\r\n\r\n'
b"Transf\xc3\xa9rer\r\n"
b"--20b303e711c4ab8c443184ac833ab00f--\r\n"
),
headers={"Content-Type": ("multipart/form-data; charset=utf-8; boundary=20b303e711c4ab8c443184ac833ab00f")},
)
assert response.json() == {"value": "Transférer"}
def test_urlencoded_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"some": "data"})
assert response.json() == {"some": "data"}
def test_no_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/")
assert response.json() == {}
def test_urlencoded_percent_encoding(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"some": "da ta"})
assert response.json() == {"some": "da ta"}
def test_urlencoded_percent_encoding_keys(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"so me": "data"})
assert response.json() == {"so me": "data"}
def test_urlencoded_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app_read_body)
response = client.post("/", data={"some": "data", "second": "key pair"})
assert response.json() == {"some": "data", "second": "key pair"}
def test_multipart_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app_read_body)
response = client.post("/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART)
assert response.json() == {"some": "data", "second": "key pair"}
def test_user_safe_decode_helper() -> None:
result = _user_safe_decode(b"\xc4\x99\xc5\xbc\xc4\x87", "utf-8")
assert result == "ężć"
def test_user_safe_decode_ignores_wrong_charset() -> None:
result = _user_safe_decode(b"abc", "latin-8")
assert result == "abc"
@pytest.mark.parametrize(
    "app,expectation",
    [
        # A bare ASGI app lets the parser error propagate to the caller
        (app, pytest.raises(MultiPartException)),
        # Wrapped in Starlette, the error is turned into a 400 response
        (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
    ],
)
def test_missing_boundary_parameter(
    app: ASGIApp,
    expectation: AbstractContextManager[Exception],
    test_client_factory: TestClientFactory,
) -> None:
    client = test_client_factory(app)
    with expectation:
        res = client.post(
            "/",
            data=(
                # file
                b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'  # type: ignore
                b"Content-Type: text/plain\r\n\r\n"
                b"<file content>\r\n"
            ),
            headers={"Content-Type": "multipart/form-data; charset=utf-8"},
        )
        assert res.status_code == 400
        assert res.text == "Missing boundary in multipart."
@pytest.mark.parametrize(
"app,expectation",
[
(app, pytest.raises(MultiPartException)),
(Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
],
)
def test_missing_name_parameter_on_content_disposition(
app: ASGIApp,
expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
with expectation:
res = client.post(
"/",
data=(
# data
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" # type: ignore
b'Content-Disposition: form-data; ="field0"\r\n\r\n'
b"value0\r\n"
),
headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
)
assert res.status_code == 400
assert res.text == 'The Content-Disposition header field "name" must be provided.'
@pytest.mark.parametrize(
"app,expectation",
[
(app, pytest.raises(MultiPartException)),
(Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
],
)
def test_too_many_fields_raise(
app: ASGIApp,
expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
fields = []
for i in range(1001):
fields.append(f'--B\r\nContent-Disposition: form-data; name="N{i}";\r\n\r\n\r\n')
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
assert res.status_code == 400
assert res.text == "Too many fields. Maximum number of fields is 1000."
@pytest.mark.parametrize(
"app,expectation",
[
(app, pytest.raises(MultiPartException)),
(Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
],
)
def test_too_many_files_raise(
app: ASGIApp,
expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
fields = []
for i in range(1001):
fields.append(f'--B\r\nContent-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n\r\n')
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
assert res.status_code == 400
assert res.text == "Too many files. Maximum number of files is 1000."
@pytest.mark.parametrize(
"app,expectation",
[
(app, pytest.raises(MultiPartException)),
(Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
],
)
def test_too_many_files_single_field_raise(
app: ASGIApp,
expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
fields = []
for i in range(1001):
# This uses the same field name "N" for all files, equivalent to a
# multifile upload form field
fields.append(f'--B\r\nContent-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n\r\n')
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
assert res.status_code == 400
assert res.text == "Too many files. Maximum number of files is 1000."
@pytest.mark.parametrize(
"app,expectation",
[
(app, pytest.raises(MultiPartException)),
(Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
],
)
def test_too_many_files_and_fields_raise(
app: ASGIApp,
expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
fields = []
for i in range(1001):
fields.append(f'--B\r\nContent-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n\r\n')
fields.append(f'--B\r\nContent-Disposition: form-data; name="N{i}";\r\n\r\n\r\n')
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
assert res.status_code == 400
assert res.text == "Too many files. Maximum number of files is 1000."
@pytest.mark.parametrize(
"app,expectation",
[
(make_app_max_parts(max_fields=1), pytest.raises(MultiPartException)),
(
Starlette(routes=[Mount("/", app=make_app_max_parts(max_fields=1))]),
does_not_raise(),
),
],
)
def test_max_fields_is_customizable_low_raises(
app: ASGIApp,
expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
fields = []
for i in range(2):
fields.append(f'--B\r\nContent-Disposition: form-data; name="N{i}";\r\n\r\n\r\n')
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
assert res.status_code == 400
assert res.text == "Too many fields. Maximum number of fields is 1."
@pytest.mark.parametrize(
"app,expectation",
[
(make_app_max_parts(max_files=1), pytest.raises(MultiPartException)),
(
Starlette(routes=[Mount("/", app=make_app_max_parts(max_files=1))]),
does_not_raise(),
),
],
)
def test_max_files_is_customizable_low_raises(
app: ASGIApp,
expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
fields = []
for i in range(2):
fields.append(f'--B\r\nContent-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n\r\n')
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
assert res.status_code == 400
assert res.text == "Too many files. Maximum number of files is 1."
def test_max_fields_is_customizable_high(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000))
fields = []
for i in range(2000):
fields.append(f'--B\r\nContent-Disposition: form-data; name="N{i}";\r\n\r\n\r\n')
fields.append(f'--B\r\nContent-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n\r\n')
data = "".join(fields).encode("utf-8")
data += b"--B--\r\n"
res = client.post(
"/",
data=data, # type: ignore
headers={"Content-Type": ("multipart/form-data; boundary=B")},
)
assert res.status_code == 200
res_data = res.json()
assert res_data["N1999"] == ""
assert res_data["F1999"] == {
"filename": "F1999",
"size": 0,
"content": "",
"content_type": None,
}
@pytest.mark.parametrize(
"app,expectation",
[
(app, pytest.raises(MultiPartException)),
(Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
],
)
def test_max_part_size_exceeds_limit(
app: ASGIApp,
expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR"
multipart_data = (
f"--{boundary}\r\n"
f'Content-Disposition: form-data; name="small"\r\n\r\n'
"small content\r\n"
f"--{boundary}\r\n"
f'Content-Disposition: form-data; name="large"\r\n\r\n'
+ ("x" * 1024 * 1024 + "x") # 1MB + 1 byte of data
+ "\r\n"
f"--{boundary}--\r\n"
).encode("utf-8")
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary}",
"Transfer-Encoding": "chunked",
}
with expectation:
response = client.post("/", data=multipart_data, headers=headers) # type: ignore
assert response.status_code == 400
assert response.text == "Part exceeded maximum size of 1024KB."
@pytest.mark.parametrize(
"app,expectation",
[
(make_app_max_parts(max_part_size=1024 * 10), pytest.raises(MultiPartException)),
(
Starlette(routes=[Mount("/", app=make_app_max_parts(max_part_size=1024 * 10))]),
does_not_raise(),
),
],
)
def test_max_part_size_exceeds_custom_limit(
app: ASGIApp,
expectation: AbstractContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
client = test_client_factory(app)
boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR"
multipart_data = (
f"--{boundary}\r\n"
f'Content-Disposition: form-data; name="small"\r\n\r\n'
"small content\r\n"
f"--{boundary}\r\n"
f'Content-Disposition: form-data; name="large"\r\n\r\n'
+ ("x" * 1024 * 10 + "x") # 10KB + 1 byte of data
+ "\r\n"
f"--{boundary}--\r\n"
).encode("utf-8")
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary}",
"Transfer-Encoding": "chunked",
}
with expectation:
response = client.post("/", content=multipart_data, headers=headers)
assert response.status_code == 400
assert response.text == "Part exceeded maximum size of 10KB."
================================================
FILE: tests/test_requests.py
================================================
from __future__ import annotations
import sys
from collections.abc import Iterator
from typing import Any
import anyio
import pytest
from starlette.datastructures import URL, Address, State
from starlette.requests import ClientDisconnect, Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.types import Message, Receive, Scope, Send
from tests.types import TestClientFactory
def test_request_url(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = {"method": request.method, "url": str(request.url)}
response = JSONResponse(data)
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/123?a=abc")
assert response.json() == {"method": "GET", "url": "http://testserver/123?a=abc"}
response = client.get("https://example.org:123/")
assert response.json() == {"method": "GET", "url": "https://example.org:123/"}
def test_request_query_params(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
params = dict(request.query_params)
response = JSONResponse({"params": params})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/?a=123&b=456")
assert response.json() == {"params": {"a": "123", "b": "456"}}
@pytest.mark.skipif(
any(module in sys.modules for module in ("brotli", "brotlicffi")),
reason='urllib3 adds "br" to the "accept-encoding" header.',
)
def test_request_headers(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
headers = dict(request.headers)
response = JSONResponse({"headers": headers})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/", headers={"host": "example.org"})
assert response.json() == {
"headers": {
"host": "example.org",
"user-agent": "testclient",
"accept-encoding": "gzip, deflate",
"accept": "*/*",
"connection": "keep-alive",
}
}
@pytest.mark.parametrize(
"scope,expected_client",
[
({"client": ["client", 42]}, Address("client", 42)),
({"client": None}, None),
({}, None),
],
)
def test_request_client(scope: Scope, expected_client: Address | None) -> None:
scope.update({"type": "http"}) # required by Request's constructor
client = Request(scope).client
assert client == expected_client
def test_request_body(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = await request.body()
response = JSONResponse({"body": body.decode()})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"body": ""}
response = client.post("/", json={"a": "123"})
assert response.json() == {"body": '{"a":"123"}'}
response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "abc"}
def test_request_stream(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = b""
async for chunk in request.stream():
body += chunk
response = JSONResponse({"body": body.decode()})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"body": ""}
response = client.post("/", json={"a": "123"})
assert response.json() == {"body": '{"a":"123"}'}
response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "abc"}
def test_request_form_urlencoded(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
form = await request.form()
response = JSONResponse({"form": dict(form)})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.post("/", data={"abc": "123 @"})
assert response.json() == {"form": {"abc": "123 @"}}
def test_request_form_context_manager(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
async with request.form() as form:
response = JSONResponse({"form": dict(form)})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.post("/", data={"abc": "123 @"})
assert response.json() == {"form": {"abc": "123 @"}}
def test_request_body_then_stream(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = await request.body()
chunks = b""
async for chunk in request.stream():
chunks += chunk
response = JSONResponse({"body": body.decode(), "stream": chunks.decode()})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "abc", "stream": "abc"}
def test_request_stream_then_body(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
chunks = b""
async for chunk in request.stream(): # pragma: no branch
chunks += chunk
try:
body = await request.body()
except RuntimeError:
body = b""
response = JSONResponse({"body": body.decode(), "stream": chunks.decode()})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "", "stream": "abc"}
def test_request_json(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.json()
response = JSONResponse({"json": data})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.post("/", json={"a": "123"})
assert response.json() == {"json": {"a": "123"}}
def test_request_scope_interface() -> None:
"""
A Request can be instantiated with a scope, and presents a `Mapping`
interface.
"""
request = Request({"type": "http", "method": "GET", "path": "/abc/"})
assert request["method"] == "GET"
assert dict(request) == {"type": "http", "method": "GET", "path": "/abc/"}
assert len(request) == 3
def test_request_raw_path(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
path = request.scope["path"]
raw_path = request.scope["raw_path"]
response = PlainTextResponse(f"{path}, {raw_path}")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/he%2Fllo")
assert response.text == "/he/llo, b'/he%2Fllo'"
def test_request_without_setting_receive(
test_client_factory: TestClientFactory,
) -> None:
"""
If Request is instantiated without the receive channel, then .body()
is not available.
"""
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope)
try:
data = await request.json()
except RuntimeError:
data = "Receive channel not available"
response = JSONResponse({"json": data})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.post("/", json={"a": "123"})
assert response.json() == {"json": "Receive channel not available"}
def test_request_disconnect(
anyio_backend_name: str,
anyio_backend_options: dict[str, Any],
) -> None:
"""
If the client disconnects while the request body is being read,
ClientDisconnect should be raised.
"""
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
await request.body()
async def receiver() -> Message:
return {"type": "http.disconnect"}
scope = {"type": "http", "method": "POST", "path": "/"}
with pytest.raises(ClientDisconnect):
anyio.run(
app, # type: ignore
scope,
receiver,
None,
backend=anyio_backend_name,
backend_options=anyio_backend_options,
)
def test_request_is_disconnected(test_client_factory: TestClientFactory) -> None:
"""
If the client disconnects after the request body has been read,
the request should report itself as disconnected.
"""
disconnected_after_response = None
async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal disconnected_after_response
request = Request(scope, receive)
body = await request.body()
disconnected = await request.is_disconnected()
response = JSONResponse({"body": body.decode(), "disconnected": disconnected})
await response(scope, receive, send)
disconnected_after_response = await request.is_disconnected()
client = test_client_factory(app)
response = client.post("/", content="foo")
assert response.json() == {"body": "foo", "disconnected": False}
assert disconnected_after_response
def test_request_state_object() -> None:
scope = {"state": {"old": "foo"}}
s = State(scope["state"])
s.new = "value"
assert s.new == "value"
del s.new
with pytest.raises(AttributeError):
s.new
# Test dictionary-style methods
# Test __setitem__
s["dict_key"] = "dict_value"
assert s["dict_key"] == "dict_value"
assert s.dict_key == "dict_value"
# Test __iter__
s["another_key"] = "another_value"
keys = list(s)
assert "old" in keys
assert "dict_key" in keys
assert "another_key" in keys
# Test __len__
assert len(s) == 3
# Test __delitem__
del s["dict_key"]
assert len(s) == 2
with pytest.raises(KeyError):
s["dict_key"]
def test_request_state(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
request.state.example = 123
response = JSONResponse({"state.example": request.state.example})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/123?a=abc")
assert response.json() == {"state.example": 123}
def test_request_cookies(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
mycookie = request.cookies.get("mycookie")
if mycookie:
response = Response(mycookie, media_type="text/plain")
else:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie("mycookie", "Hello, cookies!")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world!"
response = client.get("/")
assert response.text == "Hello, cookies!"
def test_cookie_lenient_parsing(test_client_factory: TestClientFactory) -> None:
"""
The following test is based on a cookie set by Okta, a well-known authorization
service. It turns out that it's common practice to set cookies that would be
invalid according to the spec.
"""
tough_cookie = (
"provider-oauth-nonce=validAsciiblabla; "
'okta-oauth-redirect-params={"responseType":"code","state":"somestate",'
'"nonce":"somenonce","scopes":["openid","profile","email","phone"],'
'"urls":{"issuer":"https://subdomain.okta.com/oauth2/authServer",'
'"authorizeUrl":"https://subdomain.okta.com/oauth2/authServer/v1/authorize",'
'"userinfoUrl":"https://subdomain.okta.com/oauth2/authServer/v1/userinfo"}}; '
"importantCookie=importantValue; sessionCookie=importantSessionValue"
)
expected_keys = {
"importantCookie",
"okta-oauth-redirect-params",
"provider-oauth-nonce",
"sessionCookie",
}
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/", headers={"cookie": tough_cookie})
result = response.json()
assert len(result["cookies"]) == 4
assert set(result["cookies"].keys()) == expected_keys
# These test cases are copied from Tornado's implementation.
@pytest.mark.parametrize(
"set_cookie,expected",
[
("chips=ahoy; vienna=finger", {"chips": "ahoy", "vienna": "finger"}),
# all semicolons are delimiters, even within quotes
(
'keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"',
{"keebler": '"E=mc2', "L": '\\"Loves\\"', "fudge": "\\012", "": '"'},
),
# Illegal cookies that have an '=' char in an unquoted value.
("keebler=E=mc2", {"keebler": "E=mc2"}),
# Cookies with ':' character in their name.
("key:term=value:term", {"key:term": "value:term"}),
# Cookies with '[' and ']'.
("a=b; c=[; d=r; f=h", {"a": "b", "c": "[", "d": "r", "f": "h"}),
# Cookies that RFC6265 allows.
("a=b; Domain=example.com", {"a": "b", "Domain": "example.com"}),
# cookie_parser() keeps only the last cookie with the same name.
("a=b; h=i; a=c", {"a": "c", "h": "i"}),
],
)
def test_cookies_edge_cases(
set_cookie: str,
expected: dict[str, str],
test_client_factory: TestClientFactory,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/", headers={"cookie": set_cookie})
result = response.json()
assert result["cookies"] == expected
@pytest.mark.parametrize(
"set_cookie,expected",
[
# Chunks without an equals sign appear as unnamed values per
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
(
"abc=def; unnamed; django_language=en",
{"": "unnamed", "abc": "def", "django_language": "en"},
),
# Even a double quote may be an unnamed value.
('a=b; "; c=d', {"a": "b", "": '"', "c": "d"}),
# Spaces in names and values, and an equals sign in values.
("a b c=d e = f; gh=i", {"a b c": "d e = f", "gh": "i"}),
# More characters the spec forbids.
('a b,c<>@:/[]?{}=d " =e,f g', {"a b,c<>@:/[]?{}": 'd " =e,f g'}),
# Unicode characters. The spec only allows ASCII.
# ("saint=André Bessette", {"saint": "André Bessette"}),
# Browsers don't send extra whitespace or semicolons in Cookie headers,
# but cookie_parser() should parse whitespace the same way
# document.cookie parses whitespace.
(" = b ; ; = ; c = ; ", {"": "b", "c": ""}),
],
)
def test_cookies_invalid(
set_cookie: str,
expected: dict[str, str],
test_client_factory: TestClientFactory,
) -> None:
"""
Cookie strings that are against the RFC6265 spec but which browsers will send if set
via document.cookie.
"""
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/", headers={"cookie": set_cookie})
result = response.json()
assert result["cookies"] == expected
def test_multiple_cookie_headers(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
scope["headers"] = [(b"cookie", b"a=abc"), (b"cookie", b"b=def"), (b"cookie", b"c=ghi")]
request = Request(scope, receive)
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
result = response.json()
assert result["cookies"] == {"a": "abc", "b": "def", "c": "ghi"}
def test_chunked_encoding(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = await request.body()
response = JSONResponse({"body": body.decode()})
await response(scope, receive, send)
client = test_client_factory(app)
def post_body() -> Iterator[bytes]:
yield b"foo"
yield b"bar"
response = client.post("/", data=post_body()) # type: ignore
assert response.json() == {"body": "foobar"}
def test_request_send_push_promise(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
# the server is push-enabled
scope["extensions"]["http.response.push"] = {}
request = Request(scope, receive, send)
await request.send_push_promise("/style.css")
response = JSONResponse({"json": "OK"})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"json": "OK"}
def test_request_send_push_promise_without_push_extension(
test_client_factory: TestClientFactory,
) -> None:
"""
If server does not support the `http.response.push` extension,
.send_push_promise() does nothing.
"""
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope)
await request.send_push_promise("/style.css")
response = JSONResponse({"json": "OK"})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"json": "OK"}
def test_request_send_push_promise_without_setting_send(
test_client_factory: TestClientFactory,
) -> None:
"""
If Request is instantiated without the send channel, then
.send_push_promise() is not available.
"""
async def app(scope: Scope, receive: Receive, send: Send) -> None:
# the server is push-enabled
scope["extensions"]["http.response.push"] = {}
data = "OK"
request = Request(scope)
try:
await request.send_push_promise("/style.css")
except RuntimeError:
data = "Send channel not available"
response = JSONResponse({"json": data})
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"json": "Send channel not available"}
@pytest.mark.parametrize(
"messages",
[
[{"body": b"123", "more_body": True}, {"body": b""}],
[{"body": b"", "more_body": True}, {"body": b"123"}],
[{"body": b"12", "more_body": True}, {"body": b"3"}],
[
{"body": b"123", "more_body": True},
{"body": b"", "more_body": True},
{"body": b""},
],
],
)
@pytest.mark.anyio
async def test_request_rcv(messages: list[Message]) -> None:
messages = messages.copy()
async def rcv() -> Message:
return {"type": "http.request", **messages.pop(0)}
request = Request({"type": "http"}, rcv)
body = await request.body()
assert body == b"123"
@pytest.mark.anyio
async def test_request_stream_called_twice() -> None:
messages: list[Message] = [
{"type": "http.request", "body": b"1", "more_body": True},
{"type": "http.request", "body": b"2", "more_body": True},
{"type": "http.request", "body": b"3"},
]
async def rcv() -> Message:
return messages.pop(0)
request = Request({"type": "http"}, rcv)
s1 = request.stream()
s2 = request.stream()
msg = await s1.__anext__()
assert msg == b"1"
msg = await s2.__anext__()
assert msg == b"2"
msg = await s1.__anext__()
assert msg == b"3"
# at this point we've consumed the entire body
# so we should not wait for more body (which would hang us forever)
msg = await s1.__anext__()
assert msg == b""
msg = await s2.__anext__()
assert msg == b""
# and now both streams are exhausted
with pytest.raises(StopAsyncIteration):
await s2.__anext__()
with pytest.raises(StopAsyncIteration):
await s1.__anext__()
def test_request_url_outside_starlette_context(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
request.url_for("index")
client = test_client_factory(app)
with pytest.raises(
RuntimeError,
match="The `url_for` method can only be used inside a Starlette application or with a router.",
):
client.get("/")
def test_request_url_starlette_context(test_client_factory: TestClientFactory) -> None:
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.routing import Route
from starlette.types import ASGIApp
url_for = None
async def homepage(request: Request) -> Response:
return PlainTextResponse("Hello, world!")
class CustomMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
nonlocal url_for
request = Request(scope, receive)
url_for = request.url_for("homepage")
await self.app(scope, receive, send)
app = Starlette(routes=[Route("/home", homepage)], middleware=[Middleware(CustomMiddleware)])
client = test_client_factory(app)
client.get("/home")
assert url_for == URL("http://testserver/home")
================================================
FILE: tests/test_responses.py
================================================
from __future__ import annotations
import datetime as dt
import sys
import time
from collections.abc import AsyncGenerator, AsyncIterator, Iterator
from dataclasses import dataclass
from http.cookies import SimpleCookie
from pathlib import Path
from typing import Any
import anyio
import pytest
from python_multipart import MultipartParser
from starlette import status
from starlette.background import BackgroundTask
from starlette.datastructures import Headers
from starlette.requests import ClientDisconnect, Request
from starlette.responses import FileResponse, JSONResponse, RedirectResponse, Response, StreamingResponse
from starlette.testclient import TestClient
from starlette.types import Message, Receive, Scope, Send
from tests.types import TestClientFactory
def test_text_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("hello, world", media_type="text/plain")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "hello, world"
def test_bytes_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response(b"xxxxx", media_type="image/png")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.content == b"xxxxx"
def test_json_none_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = JSONResponse(None)
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.json() is None
assert response.content == b"null"
def test_redirect_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/":
response = Response("hello, world", media_type="text/plain")
else:
response = RedirectResponse("/")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/redirect")
assert response.text == "hello, world"
assert response.url == "http://testserver/"
def test_quoting_redirect_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/I ♥ Starlette/":
response = Response("hello, world", media_type="text/plain")
else:
response = RedirectResponse("/I ♥ Starlette/")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/redirect")
assert response.text == "hello, world"
assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/"
def test_redirect_response_content_length_header(
test_client_factory: TestClientFactory,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/":
response = Response("hello", media_type="text/plain") # pragma: no cover
else:
response = RedirectResponse("/")
await response(scope, receive, send)
client: TestClient = test_client_factory(app)
response = client.request("GET", "/redirect", follow_redirects=False)
assert response.url == "http://testserver/redirect"
assert response.headers["content-length"] == "0"
def test_streaming_response(test_client_factory: TestClientFactory) -> None:
filled_by_bg_task = ""
async def app(scope: Scope, receive: Receive, send: Send) -> None:
async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]:
for i in range(minimum, maximum + 1):
yield str(i)
if i != maximum:
yield ", "
await anyio.sleep(0)
async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None:
nonlocal filled_by_bg_task
async for thing in numbers(start, stop):
filled_by_bg_task = filled_by_bg_task + thing
cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
generator = numbers(1, 5)
response = StreamingResponse(generator, media_type="text/plain", background=cleanup_task)
await response(scope, receive, send)
assert filled_by_bg_task == ""
client = test_client_factory(app)
response = client.get("/")
assert response.text == "1, 2, 3, 4, 5"
assert filled_by_bg_task == "6, 7, 8, 9"
def test_streaming_response_custom_iterator(
test_client_factory: TestClientFactory,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
class CustomAsyncIterator:
def __init__(self) -> None:
self._called = 0
def __aiter__(self) -> AsyncIterator[str]:
return self
async def __anext__(self) -> str:
if self._called == 5:
raise StopAsyncIteration()
self._called += 1
return str(self._called)
response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "12345"
def test_streaming_response_custom_iterable(
test_client_factory: TestClientFactory,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
class CustomAsyncIterable:
async def __aiter__(self) -> AsyncIterator[str | bytes]:
for i in range(5):
yield str(i + 1)
response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "12345"
def test_sync_streaming_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
def numbers(minimum: int, maximum: int) -> Iterator[str]:
for i in range(minimum, maximum + 1):
yield str(i)
if i != maximum:
yield ", "
generator = numbers(1, 5)
response = StreamingResponse(generator, media_type="text/plain")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "1, 2, 3, 4, 5"
def test_response_headers(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
headers = {"x-header-1": "123", "x-header-2": "456"}
response = Response("hello, world", media_type="text/plain", headers=headers)
response.headers["x-header-2"] = "789"
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.headers["x-header-1"] == "123"
assert response.headers["x-header-2"] == "789"
def test_response_phrase(test_client_factory: TestClientFactory) -> None:
app = Response(status_code=204)
client = test_client_factory(app)
response = client.get("/")
assert response.reason_phrase == "No Content"
app = Response(b"", status_code=123)
client = test_client_factory(app)
response = client.get("/")
assert response.reason_phrase == ""
def test_file_response(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "xyz"
content = b"<file content>" * 1000
path.write_bytes(content)
filled_by_bg_task = ""
async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]:
for i in range(minimum, maximum + 1):
yield str(i)
if i != maximum:
yield ", "
await anyio.sleep(0)
async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None:
nonlocal filled_by_bg_task
async for thing in numbers(start, stop):
filled_by_bg_task = filled_by_bg_task + thing
cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = FileResponse(path=path, filename="example.png", background=cleanup_task)
await response(scope, receive, send)
assert filled_by_bg_task == ""
client = test_client_factory(app)
response = client.get("/")
expected_disposition = 'attachment; filename="example.png"'
assert response.status_code == status.HTTP_200_OK
assert response.content == content
assert response.headers["content-type"] == "image/png"
assert response.headers["content-disposition"] == expected_disposition
assert "content-length" in response.headers
assert "last-modified" in response.headers
assert "etag" in response.headers
assert filled_by_bg_task == "6, 7, 8, 9"
@pytest.mark.anyio
async def test_file_response_on_head_method(tmp_path: Path) -> None:
path = tmp_path / "xyz"
content = b"<file content>" * 1000
path.write_bytes(content)
app = FileResponse(path=path, filename="example.png")
async def receive() -> Message: # type: ignore[empty-body]
... # pragma: no cover
async def send(message: Message) -> None:
if message["type"] == "http.response.start":
assert message["status"] == status.HTTP_200_OK
headers = Headers(raw=message["headers"])
assert headers["content-type"] == "image/png"
assert "content-length" in headers
assert "content-disposition" in headers
assert "last-modified" in headers
assert "etag" in headers
elif message["type"] == "http.response.body": # pragma: no branch
assert message["body"] == b""
assert message["more_body"] is False
# Since the TestClient drops the response body on HEAD requests, we need to test
# this directly.
await app({"type": "http", "method": "head", "headers": [(b"key", b"value")]}, receive, send)
def test_file_response_set_media_type(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "xyz"
path.write_bytes(b"")
# By default, FileResponse will determine the `content-type` based on
# the filename or path, unless a specific `media_type` is provided.
app = FileResponse(path=path, filename="example.png", media_type="image/jpeg")
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-type"] == "image/jpeg"
def test_file_response_with_directory_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
app = FileResponse(path=tmp_path, filename="example.png")
client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
client.get("/")
assert "is not a file" in str(exc_info.value)
def test_file_response_with_missing_file_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "404.txt"
app = FileResponse(path=path, filename="404.txt")
client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
client.get("/")
assert "does not exist" in str(exc_info.value)
def test_file_response_with_chinese_filename(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
content = b"file content"
filename = "你好.txt"  # "Hello.txt" in Chinese
path = tmp_path / filename
path.write_bytes(content)
app = FileResponse(path=path, filename=filename)
client = test_client_factory(app)
response = client.get("/")
expected_disposition = "attachment; filename*=utf-8''%E4%BD%A0%E5%A5%BD.txt"
assert response.status_code == status.HTTP_200_OK
assert response.content == content
assert response.headers["content-disposition"] == expected_disposition
def test_file_response_with_inline_disposition(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
content = b"file content"
filename = "hello.txt"
path = tmp_path / filename
path.write_bytes(content)
app = FileResponse(path=path, filename=filename, content_disposition_type="inline")
client = test_client_factory(app)
response = client.get("/")
expected_disposition = 'inline; filename="hello.txt"'
assert response.status_code == status.HTTP_200_OK
assert response.content == content
assert response.headers["content-disposition"] == expected_disposition
def test_file_response_with_range_header(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
content = b"file content"
filename = "hello.txt"
path = tmp_path / filename
path.write_bytes(content)
etag = '"a_non_autogenerated_etag"'
app = FileResponse(path=path, filename=filename, headers={"etag": etag})
client = test_client_factory(app)
response = client.get("/", headers={"range": "bytes=0-4", "if-range": etag})
assert response.status_code == status.HTTP_206_PARTIAL_CONTENT
assert response.content == content[:5]
assert response.headers["etag"] == etag
assert response.headers["content-length"] == "5"
assert response.headers["content-range"] == f"bytes 0-4/{len(content)}"
@pytest.mark.anyio
async def test_file_response_with_pathsend(tmp_path: Path) -> None:
path = tmp_path / "xyz"
content = b"<file content>" * 1000
with open(path, "wb") as file:
file.write(content)
app = FileResponse(path=path, filename="example.png")
async def receive() -> Message: # type: ignore[empty-body]
... # pragma: no cover
async def send(message: Message) -> None:
if message["type"] == "http.response.start":
assert message["status"] == status.HTTP_200_OK
headers = Headers(raw=message["headers"])
assert headers["content-type"] == "image/png"
assert "content-length" in headers
assert "content-disposition" in headers
assert "last-modified" in headers
assert "etag" in headers
elif message["type"] == "http.response.pathsend": # pragma: no branch
assert message["path"] == str(path)
# Since the TestClient doesn't support `pathsend`, we need to test this directly.
await app(
{"type": "http", "method": "get", "headers": [], "extensions": {"http.response.pathsend": {}}},
receive,
send,
)
def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None:
# Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie(
"mycookie",
"myvalue",
max_age=10,
expires=10,
path="/",
domain="localhost",
secure=True,
httponly=True,
samesite="none",
partitioned=sys.version_info >= (3, 14),
)
await response(scope, receive, send)
partitioned_text = "Partitioned; " if sys.version_info >= (3, 14) else ""
client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world!"
assert (
response.headers["set-cookie"] == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; "
f"HttpOnly; Max-Age=10; {partitioned_text}Path=/; SameSite=none; Secure"
)
@pytest.mark.skipif(sys.version_info >= (3, 14), reason="Only relevant for <3.14")
def test_set_cookie_raises_for_invalid_python_version(
test_client_factory: TestClientFactory,
) -> None: # pragma: no cover
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
with pytest.raises(ValueError):
response.set_cookie("mycookie", "myvalue", partitioned=True)
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world!"
assert response.headers.get("set-cookie") is None
def test_set_cookie_path_none(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie("mycookie", "myvalue", path=None)
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world!"
assert response.headers["set-cookie"] == "mycookie=myvalue; SameSite=lax"
def test_set_cookie_samesite_none(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie("mycookie", "myvalue", samesite=None)
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world!"
assert response.headers["set-cookie"] == "mycookie=myvalue; Path=/"
@pytest.mark.parametrize(
"expires",
[
pytest.param(dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime"),
pytest.param("Thu, 22 Jan 2037 12:00:10 GMT", id="str"),
pytest.param(10, id="int"),
],
)
def test_expires_on_set_cookie(
test_client_factory: TestClientFactory,
monkeypatch: pytest.MonkeyPatch,
expires: dt.datetime | str | int,
) -> None:
# Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie("mycookie", "myvalue", expires=expires)
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
cookie = SimpleCookie(response.headers.get("set-cookie"))
assert cookie["mycookie"]["expires"] == "Thu, 22 Jan 2037 12:00:10 GMT"
def test_delete_cookie(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = Response("Hello, world!", media_type="text/plain")
if request.cookies.get("mycookie"):
response.delete_cookie("mycookie")
else:
response.set_cookie("mycookie", "myvalue")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.cookies["mycookie"]
response = client.get("/")
assert not response.cookies.get("mycookie")
def test_populate_headers(test_client_factory: TestClientFactory) -> None:
app = Response(content="hi", headers={}, media_type="text/html")
client = test_client_factory(app)
response = client.get("/")
assert response.text == "hi"
assert response.headers["content-length"] == "2"
assert response.headers["content-type"] == "text/html; charset=utf-8"
def test_head_method(test_client_factory: TestClientFactory) -> None:
app = Response("hello, world", media_type="text/plain")
client = test_client_factory(app)
response = client.head("/")
assert response.text == ""
def test_empty_response(test_client_factory: TestClientFactory) -> None:
app = Response()
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.content == b""
assert response.headers["content-length"] == "0"
assert "content-type" not in response.headers
def test_empty_204_response(test_client_factory: TestClientFactory) -> None:
app = Response(status_code=204)
client: TestClient = test_client_factory(app)
response = client.get("/")
assert "content-length" not in response.headers
def test_non_empty_response(test_client_factory: TestClientFactory) -> None:
app = Response(content="hi")
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == "2"
def test_response_do_not_add_redundant_charset(
test_client_factory: TestClientFactory,
) -> None:
app = Response(media_type="text/plain; charset=utf-8")
client = test_client_factory(app)
response = client.get("/")
assert response.headers["content-type"] == "text/plain; charset=utf-8"
def test_file_response_known_size(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "xyz"
content = b"<file content>" * 1000
path.write_bytes(content)
app = FileResponse(path=path, filename="example.png")
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == str(len(content))
def test_streaming_response_unknown_size(
test_client_factory: TestClientFactory,
) -> None:
app = StreamingResponse(content=iter(["hello", "world"]))
client: TestClient = test_client_factory(app)
response = client.get("/")
assert "content-length" not in response.headers
def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None:
app = StreamingResponse(content=iter(["hello", "world"]), headers={"content-length": "10"})
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == "10"
def test_response_memoryview(test_client_factory: TestClientFactory) -> None:
app = Response(content=memoryview(b"\xc0"))
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.content == b"\xc0"
def test_streaming_response_memoryview(test_client_factory: TestClientFactory) -> None:
app = StreamingResponse(content=iter([memoryview(b"\xc0"), memoryview(b"\xf5")]))
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.content == b"\xc0\xf5"
@pytest.mark.anyio
async def test_streaming_response_stops_if_receiving_http_disconnect() -> None:
    streamed = 0

    disconnected = anyio.Event()

    async def receive_disconnect() -> Message:
        await disconnected.wait()
        return {"type": "http.disconnect"}

    async def send(message: Message) -> None:
        nonlocal streamed
        if message["type"] == "http.response.body":
            streamed += len(message.get("body", b""))
            # Simulate disconnection after download has started
            if streamed >= 16:
                disconnected.set()

    async def stream_indefinitely() -> AsyncIterator[bytes]:
        while True:
            # Need a sleep for the event loop to switch to another task
            await anyio.sleep(0)
            yield b"chunk "

    response = StreamingResponse(content=stream_indefinitely())

    with anyio.move_on_after(1) as cancel_scope:
        await response({"type": "http"}, receive_disconnect, send)
    assert not cancel_scope.cancel_called, "Content streaming should stop itself."
@pytest.mark.anyio
async def test_streaming_response_on_client_disconnects() -> None:
    chunks = bytearray()
    streamed = False

    async def receive_disconnect() -> Message:
        raise NotImplementedError

    async def send(message: Message) -> None:
        nonlocal streamed
        if message["type"] == "http.response.body":
            if not streamed:
                chunks.extend(message.get("body", b""))
                streamed = True
            else:
                raise OSError

    async def stream_indefinitely() -> AsyncGenerator[bytes, None]:
        while True:
            await anyio.sleep(0)
            yield b"chunk"

    stream = stream_indefinitely()
    response = StreamingResponse(content=stream)

    with anyio.move_on_after(1) as cancel_scope:
        with pytest.raises(ClientDisconnect):
            await response({"type": "http", "asgi": {"spec_version": "2.4"}}, receive_disconnect, send)
    assert not cancel_scope.cancel_called, "Content streaming should stop itself."
    assert chunks == b"chunk"
    await stream.aclose()
@pytest.mark.anyio
async def test_streaming_response_runs_background_on_websocket_scope() -> None:
    background_called = False
    sent: list[Message] = []

    async def receive() -> Message:
        return {}  # pragma: no cover

    async def send(message: Message) -> None:
        sent.append(message)

    def run_background() -> None:
        nonlocal background_called
        background_called = True

    async def stream() -> AsyncIterator[bytes]:
        yield b"chunk"

    response = StreamingResponse(stream(), background=BackgroundTask(run_background))
    await response({"type": "websocket"}, receive, send)

    assert background_called
    assert [message["type"] for message in sent] == [
        "websocket.http.response.start",
        "websocket.http.response.body",
        "websocket.http.response.body",
    ]
README = """\
# BáiZé

Powerful and exquisite WSGI/ASGI framework/toolkit.

The minimize implementation of methods required in the Web framework. No redundant implementation means that you can freely customize functions without considering the conflict with baize's own implementation.

Under the ASGI/WSGI protocol, the interface of the request object and the response object is almost the same, only need to add or delete `await` in the appropriate place. In addition, it should be noted that ASGI supports WebSocket but WSGI does not.
"""  # noqa: E501


@pytest.fixture
def readme_file(tmp_path: Path) -> Path:
    filepath = tmp_path / "README.txt"
    filepath.write_bytes(README.encode("utf8"))
    return filepath


@pytest.fixture
def file_response_client(readme_file: Path, test_client_factory: TestClientFactory) -> TestClient:
    return test_client_factory(app=FileResponse(str(readme_file)))
def test_file_response_without_range(file_response_client: TestClient) -> None:
    response = file_response_client.get("/")
    assert response.status_code == 200
    assert "content-range" not in response.headers
    assert response.headers["content-length"] == str(len(README.encode("utf8")))
    assert response.headers["content-type"] == "text/plain; charset=utf-8"
    assert response.text == README


def test_file_response_head(file_response_client: TestClient) -> None:
    response = file_response_client.head("/")
    assert response.status_code == 200
    assert "content-range" not in response.headers
    assert response.headers["content-length"] == str(len(README.encode("utf8")))
    assert response.headers["content-type"] == "text/plain; charset=utf-8"
    assert response.content == b""


def test_file_response_range(file_response_client: TestClient) -> None:
    response = file_response_client.get("/", headers={"Range": "bytes=0-100"})
    assert response.status_code == 206
    assert response.headers["content-range"] == f"bytes 0-100/{len(README.encode('utf8'))}"
    assert response.headers["content-length"] == "101"
    assert response.headers["content-type"] == "text/plain; charset=utf-8"
    assert response.content == README.encode("utf8")[:101]


def test_file_response_range_head(file_response_client: TestClient) -> None:
    response = file_response_client.head("/", headers={"Range": "bytes=0-100"})
    assert response.status_code == 206
    assert response.headers["content-range"] == f"bytes 0-100/{len(README.encode('utf8'))}"
    assert response.headers["content-length"] == str(101)
    assert response.headers["content-type"] == "text/plain; charset=utf-8"
    assert response.content == b""
def test_file_response_range_multi(file_response_client: TestClient) -> None:
    response = file_response_client.get("/", headers={"Range": "bytes=0-100, 200-300"})
    assert response.status_code == 206
    assert "content-range" not in response.headers
    assert response.headers["content-length"] == "448"
    assert response.headers["content-type"].startswith("multipart/byteranges; boundary=")


def test_file_response_range_multi_head(file_response_client: TestClient) -> None:
    response = file_response_client.head("/", headers={"Range": "bytes=0-100, 200-300"})
    assert response.status_code == 206
    assert "content-range" not in response.headers
    assert response.headers["content-length"] == "448"
    assert response.headers["content-type"].startswith("multipart/byteranges; boundary=")
    assert response.content == b""

    response = file_response_client.head(
        "/",
        headers={"Range": "bytes=200-300", "if-range": response.headers["etag"][:-1]},
    )
    assert response.status_code == 200
    response = file_response_client.head(
        "/",
        headers={"Range": "bytes=200-300", "if-range": response.headers["etag"]},
    )
    assert response.status_code == 206
def test_file_response_range_invalid(file_response_client: TestClient) -> None:
    response = file_response_client.head("/", headers={"Range": "bytes: 0-1000"})
    assert response.status_code == 400


def test_file_response_range_head_max(file_response_client: TestClient) -> None:
    response = file_response_client.head("/", headers={"Range": f"bytes=0-{len(README.encode('utf8')) + 1}"})
    assert response.status_code == 206


def test_file_response_range_416(file_response_client: TestClient) -> None:
    response = file_response_client.head("/", headers={"Range": f"bytes={len(README.encode('utf8')) + 1}-"})
    assert response.status_code == 416
    assert response.headers["Content-Range"] == f"bytes */{len(README.encode('utf8'))}"


def test_file_response_only_support_bytes_range(file_response_client: TestClient) -> None:
    response = file_response_client.get("/", headers={"Range": "items=0-100"})
    assert response.status_code == 400
    assert response.text == "Only support bytes range"


def test_file_response_range_must_be_requested(file_response_client: TestClient) -> None:
    response = file_response_client.get("/", headers={"Range": "bytes="})
    assert response.status_code == 400
    assert response.text == "Range header: range must be requested"


def test_file_response_start_must_be_less_than_end(file_response_client: TestClient) -> None:
    response = file_response_client.get("/", headers={"Range": "bytes=100-0"})
    assert response.status_code == 400
    assert response.text == "Range header: start must be less than end"


def test_file_response_merge_ranges(file_response_client: TestClient) -> None:
    response = file_response_client.get("/", headers={"Range": "bytes=0-100, 50-200"})
    assert response.status_code == 206
    assert response.headers["content-length"] == "201"
    assert response.headers["content-range"] == f"bytes 0-200/{len(README.encode('utf8'))}"
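

# Illustrative sketch only, not Starlette's implementation: one way that
# overlapping inclusive byte ranges (for example "0-100, 50-200") could be
# coalesced into a single range, as the merge test above expects.
# `_merge_inclusive_ranges` is a hypothetical helper introduced for this
# example; it merges ranges that overlap and leaves disjoint ranges apart.
def _merge_inclusive_ranges(ranges: list[tuple[int, int]]) -> list[tuple[int, int]]:
    merged: list[tuple[int, int]] = []
    for start, end in sorted(ranges):
        if merged and start <= merged[-1][1]:
            # Overlaps the previous range: extend it instead of adding a new part.
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged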
@dataclass
class MultipartPart:
    headers: dict[bytes, bytes]
    data: bytes


def parse_multipart_data(data: bytes, boundary: bytes | str) -> list[MultipartPart]:
    parts: list[MultipartPart] = []
    done = False
    current_headers: dict[bytes, bytes] = {}
    current_header_field: bytes = b""

    def on_part_begin() -> None:
        nonlocal current_headers
        current_headers = {}

    def on_part_data(data: bytes, start: int, end: int) -> None:
        parts.append(MultipartPart(current_headers, data[start:end]))

    def on_header_field(data: bytes, start: int, end: int) -> None:
        nonlocal current_header_field
        current_header_field = data[start:end]

    def on_header_value(data: bytes, start: int, end: int) -> None:
        current_headers[current_header_field] = data[start:end]

    def on_end() -> None:
        nonlocal done
        done = True

    parser = MultipartParser(
        boundary,
        dict(
            on_part_begin=on_part_begin,
            on_part_data=on_part_data,
            on_header_field=on_header_field,
            on_header_value=on_header_value,
            on_end=on_end,
        ),
    )
    parser.write(data)
    parser.finalize()
    assert done
    return parts
def test_file_response_insert_ranges(file_response_client: TestClient) -> None:
    response = file_response_client.get("/", headers={"Range": "bytes=100-200, 0-50"})

    assert response.status_code == 206
    assert "content-range" not in response.headers
    assert response.headers["content-type"].startswith("multipart/byteranges; boundary=")

    boundary = response.headers["content-type"].split("boundary=")[1]
    assert response.text.splitlines() == [
        f"--{boundary}",
        "Content-Type: text/plain; charset=utf-8",
        "Content-Range: bytes 0-50/526",
        "",
        "# BáiZé",
        "",
        "Powerful and exquisite WSGI/ASGI framewo",
        f"--{boundary}",
        "Content-Type: text/plain; charset=utf-8",
        "Content-Range: bytes 100-200/526",
        "",
        "ds required in the Web framework. No redundant implementation means that you can freely customize fun",
        f"--{boundary}--",
    ]

    parts = parse_multipart_data(response._content, boundary)
    assert all(
        value == b"text/plain; charset=utf-8"
        for part in parts
        for key, value in part.headers.items()
        if key == b"Content-Type"
    )
    assert len(parts) == 2
    assert parts[0].headers[b"Content-Range"] == b"bytes 0-50/526"
    assert parts[0].data == "# BáiZé\n\nPowerful and exquisite WSGI/ASGI framewo".encode()
    assert parts[1].headers[b"Content-Range"] == b"bytes 100-200/526"
    assert (
        parts[1].data
        == b"ds required in the Web framework. No redundant implementation means that you can freely customize fun"
    )
def test_file_response_range_without_dash(file_response_client: TestClient) -> None:
    response = file_response_client.get("/", headers={"Range": "bytes=100, 0-50"})
    assert response.status_code == 206
    assert response.headers["content-range"] == f"bytes 0-50/{len(README.encode('utf8'))}"


def test_file_response_range_empty_start_and_end(file_response_client: TestClient) -> None:
    response = file_response_client.get("/", headers={"Range": "bytes= - , 0-50"})
    assert response.status_code == 206
    assert response.headers["content-range"] == f"bytes 0-50/{len(README.encode('utf8'))}"


def test_file_response_range_ignore_non_numeric(file_response_client: TestClient) -> None:
    response = file_response_client.get("/", headers={"Range": "bytes=abc-def, 0-50"})
    assert response.status_code == 206
    assert response.headers["content-range"] == f"bytes 0-50/{len(README.encode('utf8'))}"


def test_file_response_suffix_range(file_response_client: TestClient) -> None:
    # Test suffix range (last N bytes) - line 523 with empty start_str
    response = file_response_client.get("/", headers={"Range": "bytes=-100"})
    assert response.status_code == 206
    file_size = len(README.encode("utf8"))
    assert response.headers["content-range"] == f"bytes {file_size - 100}-{file_size - 1}/{file_size}"
    assert response.headers["content-length"] == "100"
    assert response.content == README.encode("utf8")[-100:]
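

# Illustrative sketch only, not Starlette's parser: how a suffix range such as
# "bytes=-100" maps to the inclusive (start, end) pair that the suffix test
# above checks in the Content-Range header. `_resolve_suffix_range` is a
# hypothetical helper; clamping the start to 0 when the suffix is longer than
# the file is an assumption of this sketch.
def _resolve_suffix_range(suffix_length: int, file_size: int) -> tuple[int, int]:
    start = max(file_size - suffix_length, 0)
    return start, file_size - 1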
def test_file_response_multiple_calls(file_response_client: TestClient) -> None:
    response = file_response_client.get("/", headers={"Range": "bytes=0-100"})
    assert response.status_code == 206

    response = file_response_client.get("/")
    assert response.status_code == 200
    assert "content-range" not in response.headers
    assert response.headers["content-length"] == str(len(README.encode("utf8")))
    assert response.headers["content-type"] == "text/plain; charset=utf-8"


@pytest.mark.anyio
async def test_file_response_multi_small_chunk_size(readme_file: Path) -> None:
    class SmallChunkSizeFileResponse(FileResponse):
        chunk_size = 10

    app = SmallChunkSizeFileResponse(path=str(readme_file))

    received_chunks: list[bytes] = []
    start_message: dict[str, Any] = {}

    async def receive() -> Message:
        raise NotImplementedError("Should not be called!")

    async def send(message: Message) -> None:
        if message["type"] == "http.response.start":
            start_message.update(message)
        elif message["type"] == "http.response.body":  # pragma: no branch
            received_chunks.append(message["body"])

    await app({"type": "http", "method": "get", "headers": [(b"range", b"bytes=0-15,20-35,35-50")]}, receive, send)
    assert start_message["status"] == 206

    headers = Headers(raw=start_message["headers"])
    assert "content-range" not in headers
    assert headers.get("accept-ranges") == "bytes"
    assert "content-length" in headers
    assert "last-modified" in headers
    assert "etag" in headers
    assert headers["content-type"].startswith("multipart/byteranges; boundary=")
    boundary = headers["content-type"].split("boundary=")[1]

    assert received_chunks == [
        # Send the part headers.
        f"--{boundary}\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Range: bytes 0-15/526\r\n\r\n".encode(),
        # Send the first chunk (10 bytes).
        b"# B\xc3\xa1iZ\xc3\xa9\n",
        # Send the second chunk (6 bytes).
        b"\nPower",
        # Send the new line to separate the parts.
        b"\r\n",
        # Send the part headers. We merge the ranges 20-35 and 35-50 into a single part.
        f"--{boundary}\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Range: bytes 20-50/526\r\n\r\n".encode(),
        # Send the first chunk (10 bytes).
        b"and exquis",
        # Send the second chunk (10 bytes).
        b"ite WSGI/A",
        # Send the third chunk (10 bytes).
        b"SGI framew",
        # Send the last chunk (1 byte).
        b"o",
        b"\r\n",
        f"--{boundary}--".encode(),
    ]
================================================
FILE: tests/test_routing.py
================================================
from __future__ import annotations
import contextlib
import functools
import json
import uuid
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator
from typing import TypedDict
import pytest
from typing_extensions import Never
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect
from tests.types import TestClientFactory
def homepage(request: Request) -> Response:
    return Response("Hello, world", media_type="text/plain")


def users(request: Request) -> Response:
    return Response("All users", media_type="text/plain")


def user(request: Request) -> Response:
    content = "User " + request.path_params["username"]
    return Response(content, media_type="text/plain")


def user_me(request: Request) -> Response:
    content = "User fixed me"
    return Response(content, media_type="text/plain")


def disable_user(request: Request) -> Response:
    content = "User " + request.path_params["username"] + " disabled"
    return Response(content, media_type="text/plain")


def user_no_match(request: Request) -> Response:  # pragma: no cover
    content = "User fixed no match"
    return Response(content, media_type="text/plain")


async def partial_endpoint(arg: str, request: Request) -> JSONResponse:
    return JSONResponse({"arg": arg})


async def partial_ws_endpoint(websocket: WebSocket) -> None:
    await websocket.accept()
    await websocket.send_json({"url": str(websocket.url)})
    await websocket.close()


class PartialRoutes:
    @classmethod
    async def async_endpoint(cls, arg: str, request: Request) -> JSONResponse:
        return JSONResponse({"arg": arg})

    @classmethod
    async def async_ws_endpoint(cls, websocket: WebSocket) -> None:
        await websocket.accept()
        await websocket.send_json({"url": str(websocket.url)})
        await websocket.close()


def func_homepage(request: Request) -> Response:
    return Response("Hello, world!", media_type="text/plain")


def contact(request: Request) -> Response:
    return Response("Hello, POST!", media_type="text/plain")


def int_convertor(request: Request) -> JSONResponse:
    number = request.path_params["param"]
    return JSONResponse({"int": number})


def float_convertor(request: Request) -> JSONResponse:
    num = request.path_params["param"]
    return JSONResponse({"float": num})


def path_convertor(request: Request) -> JSONResponse:
    path = request.path_params["param"]
    return JSONResponse({"path": path})


def uuid_converter(request: Request) -> JSONResponse:
    uuid_param = request.path_params["param"]
    return JSONResponse({"uuid": str(uuid_param)})


def path_with_parentheses(request: Request) -> JSONResponse:
    number = request.path_params["param"]
    return JSONResponse({"int": number})


async def websocket_endpoint(session: WebSocket) -> None:
    await session.accept()
    await session.send_text("Hello, world!")
    await session.close()


async def websocket_params(session: WebSocket) -> None:
    await session.accept()
    await session.send_text(f"Hello, {session.path_params['room']}!")
    await session.close()
app = Router(
    [
        Route("/", endpoint=homepage, methods=["GET"]),
        Mount(
            "/users",
            routes=[
                Route("/", endpoint=users),
                Route("/me", endpoint=user_me),
                Route("/{username}", endpoint=user),
                Route("/{username}:disable", endpoint=disable_user, methods=["PUT"]),
                Route("/nomatch", endpoint=user_no_match),
            ],
        ),
        Mount(
            "/partial",
            routes=[
                Route("/", endpoint=functools.partial(partial_endpoint, "foo")),
                Route(
                    "/cls",
                    endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"),
                ),
                WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)),
                WebSocketRoute(
                    "/ws/cls",
                    endpoint=functools.partial(PartialRoutes.async_ws_endpoint),
                ),
            ],
        ),
        Mount("/static", app=Response("xxxxx", media_type="image/png")),
        Route("/func", endpoint=func_homepage, methods=["GET"]),
        Route("/func", endpoint=contact, methods=["POST"]),
        Route("/int/{param:int}", endpoint=int_convertor, name="int-convertor"),
        Route("/float/{param:float}", endpoint=float_convertor, name="float-convertor"),
        Route("/path/{param:path}", endpoint=path_convertor, name="path-convertor"),
        Route("/uuid/{param:uuid}", endpoint=uuid_converter, name="uuid-convertor"),
        # Route with chars that conflict with regex meta chars
        Route(
            "/path-with-parentheses({param:int})",
            endpoint=path_with_parentheses,
            name="path-with-parentheses",
        ),
        WebSocketRoute("/ws", endpoint=websocket_endpoint),
        WebSocketRoute("/ws/{room}", endpoint=websocket_params),
    ]
)


@pytest.fixture
def client(
    test_client_factory: TestClientFactory,
) -> Generator[TestClient, None, None]:
    with test_client_factory(app) as client:
        yield client
@pytest.mark.filterwarnings(
    r"ignore"
    r":Trying to detect encoding from a tiny portion of \(5\) byte\(s\)\."
    r":UserWarning"
    r":charset_normalizer.api"
)
def test_router(client: TestClient) -> None:
    response = client.get("/")
    assert response.status_code == 200
    assert response.text == "Hello, world"

    response = client.post("/")
    assert response.status_code == 405
    assert response.text == "Method Not Allowed"
    assert set(response.headers["allow"].split(", ")) == {"HEAD", "GET"}

    response = client.get("/foo")
    assert response.status_code == 404
    assert response.text == "Not Found"

    response = client.get("/users")
    assert response.status_code == 200
    assert response.text == "All users"

    response = client.get("/users/tomchristie")
    assert response.status_code == 200
    assert response.text == "User tomchristie"

    response = client.get("/users/me")
    assert response.status_code == 200
    assert response.text == "User fixed me"

    response = client.get("/users/tomchristie/")
    assert response.status_code == 200
    assert response.url == "http://testserver/users/tomchristie"
    assert response.text == "User tomchristie"

    response = client.put("/users/tomchristie:disable")
    assert response.status_code == 200
    assert response.url == "http://testserver/users/tomchristie:disable"
    assert response.text == "User tomchristie disabled"

    response = client.get("/users/nomatch")
    assert response.status_code == 200
    assert response.text == "User nomatch"

    response = client.get("/static/123")
    assert response.status_code == 200
    assert response.text == "xxxxx"
def test_route_converters(client: TestClient) -> None:
    # Test integer conversion
    response = client.get("/int/5")
    assert response.status_code == 200
    assert response.json() == {"int": 5}
    assert app.url_path_for("int-convertor", param=5) == "/int/5"

    # Test path with parentheses
    response = client.get("/path-with-parentheses(7)")
    assert response.status_code == 200
    assert response.json() == {"int": 7}
    assert app.url_path_for("path-with-parentheses", param=7) == "/path-with-parentheses(7)"

    # Test float conversion
    response = client.get("/float/25.5")
    assert response.status_code == 200
    assert response.json() == {"float": 25.5}
    assert app.url_path_for("float-convertor", param=25.5) == "/float/25.5"

    # Test path conversion
    response = client.get("/path/some/example")
    assert response.status_code == 200
    assert response.json() == {"path": "some/example"}
    assert app.url_path_for("path-convertor", param="some/example") == "/path/some/example"

    # Test UUID conversion
    response = client.get("/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a")
    assert response.status_code == 200
    assert response.json() == {"uuid": "ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"}
    assert (
        app.url_path_for("uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"))
        == "/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"
    )
def test_url_path_for() -> None:
    assert app.url_path_for("homepage") == "/"
    assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
    assert app.url_path_for("websocket_endpoint") == "/ws"
    with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "".'):
        assert app.url_path_for("broken")
    with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "key, key2".'):
        assert app.url_path_for("broken", key="value", key2="value2")
    with pytest.raises(AssertionError):
        app.url_path_for("user", username="tom/christie")
    with pytest.raises(AssertionError):
        app.url_path_for("user", username="")


def test_url_for() -> None:
    assert app.url_path_for("homepage").make_absolute_url(base_url="https://example.org") == "https://example.org/"
    assert (
        app.url_path_for("homepage").make_absolute_url(base_url="https://example.org/root_path/")
        == "https://example.org/root_path/"
    )
    assert (
        app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org")
        == "https://example.org/users/tomchristie"
    )
    assert (
        app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org/root_path/")
        == "https://example.org/root_path/users/tomchristie"
    )
    assert (
        app.url_path_for("websocket_endpoint").make_absolute_url(base_url="https://example.org")
        == "wss://example.org/ws"
    )
def test_router_add_route(client: TestClient) -> None:
    response = client.get("/func")
    assert response.status_code == 200
    assert response.text == "Hello, world!"


def test_router_duplicate_path(client: TestClient) -> None:
    response = client.post("/func")
    assert response.status_code == 200
    assert response.text == "Hello, POST!"


def test_router_add_websocket_route(client: TestClient) -> None:
    with client.websocket_connect("/ws") as session:
        text = session.receive_text()
        assert text == "Hello, world!"

    with client.websocket_connect("/ws/test") as session:
        text = session.receive_text()
        assert text == "Hello, test!"


def test_router_middleware(test_client_factory: TestClientFactory) -> None:
    class CustomMiddleware:
        def __init__(self, app: ASGIApp) -> None:
            self.app = app

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            response = PlainTextResponse("OK")
            await response(scope, receive, send)

    app = Router(
        routes=[Route("/", homepage)],
        middleware=[Middleware(CustomMiddleware)],
    )

    client = test_client_factory(app)
    response = client.get("/")
    assert response.status_code == 200
    assert response.text == "OK"
def http_endpoint(request: Request) -> Response:
    url = request.url_for("http_endpoint")
    return Response(f"URL: {url}", media_type="text/plain")


class WebSocketEndpoint:
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope=scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.send_json({"URL": str(websocket.url_for("websocket_endpoint"))})
        await websocket.close()


mixed_protocol_app = Router(
    routes=[
        Route("/", endpoint=http_endpoint),
        WebSocketRoute("/", endpoint=WebSocketEndpoint(), name="websocket_endpoint"),
    ]
)


def test_protocol_switch(test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(mixed_protocol_app)

    response = client.get("/")
    assert response.status_code == 200
    assert response.text == "URL: http://testserver/"

    with client.websocket_connect("/") as session:
        assert session.receive_json() == {"URL": "ws://testserver/"}

    with pytest.raises(WebSocketDisconnect):
        with client.websocket_connect("/404"):
            pass  # pragma: no cover
ok = PlainTextResponse("OK")


def test_mount_urls(test_client_factory: TestClientFactory) -> None:
    mounted = Router([Mount("/users", ok, name="users")])
    client = test_client_factory(mounted)
    assert client.get("/users").status_code == 200
    assert client.get("/users").url == "http://testserver/users/"
    assert client.get("/users/").status_code == 200
    assert client.get("/users/a").status_code == 200
    assert client.get("/usersa").status_code == 404


def test_reverse_mount_urls() -> None:
    mounted = Router([Mount("/users", ok, name="users")])
    assert mounted.url_path_for("users", path="/a") == "/users/a"

    users = Router([Route("/{username}", ok, name="user")])
    mounted = Router([Mount("/{subpath}/users", users, name="users")])
    assert mounted.url_path_for("users:user", subpath="test", username="tom") == "/test/users/tom"
    assert mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom"

    mounted = Router([Mount("/users", ok, name="users")])
    with pytest.raises(NoMatchFound):
        mounted.url_path_for("users", path="/a", foo="bar")

    mounted = Router([Mount("/users", ok, name="users")])
    with pytest.raises(NoMatchFound):
        mounted.url_path_for("users")


def test_mount_at_root(test_client_factory: TestClientFactory) -> None:
    mounted = Router([Mount("/", ok, name="users")])
    client = test_client_factory(mounted)
    assert client.get("/").status_code == 200
def users_api(request: Request) -> JSONResponse:
    return JSONResponse({"users": [{"username": "tom"}]})


mixed_hosts_app = Router(
    routes=[
        Host(
            "www.example.org",
            app=Router(
                [
                    Route("/", homepage, name="homepage"),
                    Route("/users", users, name="users"),
                ]
            ),
        ),
        Host(
            "api.example.org",
            name="api",
            app=Router([Route("/users", users_api, name="users")]),
        ),
        Host(
            "port.example.org:3600",
            name="port",
            app=Router([Route("/", homepage, name="homepage")]),
        ),
    ]
)


def test_host_routing(test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(mixed_hosts_app, base_url="https://api.example.org/")

    response = client.get("/users")
    assert response.status_code == 200
    assert response.json() == {"users": [{"username": "tom"}]}

    response = client.get("/")
    assert response.status_code == 404

    client = test_client_factory(mixed_hosts_app, base_url="https://www.example.org/")

    response = client.get("/users")
    assert response.status_code == 200
    assert response.text == "All users"

    response = client.get("/")
    assert response.status_code == 200

    client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:3600/")

    response = client.get("/users")
    assert response.status_code == 404

    response = client.get("/")
    assert response.status_code == 200

    # Port in requested Host is irrelevant.

    client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/")

    response = client.get("/")
    assert response.status_code == 200

    client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:5600/")

    response = client.get("/")
    assert response.status_code == 200


def test_host_reverse_urls() -> None:
    assert mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever") == "https://www.example.org/"
    assert (
        mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever") == "https://www.example.org/users"
    )
    assert (
        mixed_hosts_app.url_path_for("api:users").make_absolute_url("https://whatever")
        == "https://api.example.org/users"
    )
    assert (
        mixed_hosts_app.url_path_for("port:homepage").make_absolute_url("https://whatever")
        == "https://port.example.org:3600/"
    )
    with pytest.raises(NoMatchFound):
        mixed_hosts_app.url_path_for("api", path="whatever", foo="bar")
async def subdomain_app(scope: Scope, receive: Receive, send: Send) -> None:
    response = JSONResponse({"subdomain": scope["path_params"]["subdomain"]})
    await response(scope, receive, send)


subdomain_router = Router(routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")])


def test_subdomain_routing(test_client_factory: TestClientFactory) -> None:
    client = test_client_factory(subdomain_router, base_url="https://foo.example.org/")

    response = client.get("/")
    assert response.status_code == 200
    assert response.json() == {"subdomain": "foo"}


def test_subdomain_reverse_urls() -> None:
    assert (
        subdomain_router.url_path_for("subdomains", subdomain="foo", path="/homepage").make_absolute_url(
            "https://whatever"
        )
        == "https://foo.example.org/homepage"
    )


async def echo_urls(request: Request) -> JSONResponse:
    return JSONResponse(
        {
            "index": str(request.url_for("index")),
            "submount": str(request.url_for("mount:submount")),
        }
    )


echo_url_routes = [
    Route("/", echo_urls, name="index", methods=["GET"]),
    Mount(
        "/submount",
        name="mount",
        routes=[Route("/", echo_urls, name="submount", methods=["GET"])],
    ),
]


def test_url_for_with_root_path(test_client_factory: TestClientFactory) -> None:
    app = Starlette(routes=echo_url_routes)
    client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path")
    response = client.get("/sub_path/")
    assert response.json() == {
        "index": "https://www.example.org/sub_path/",
        "submount": "https://www.example.org/sub_path/submount/",
    }
    response = client.get("/sub_path/submount/")
    assert response.json() == {
        "index": "https://www.example.org/sub_path/",
        "submount": "https://www.example.org/sub_path/submount/",
    }


async def stub_app(scope: Scope, receive: Receive, send: Send) -> None:
    pass  # pragma: no cover


double_mount_routes = [
    Mount("/mount", name="mount", routes=[Mount("/static", stub_app, name="static")]),
]


def test_url_for_with_double_mount() -> None:
    app = Starlette(routes=double_mount_routes)
    url = app.url_path_for("mount:static", path="123")
    assert url == "/mount/static/123"


def test_url_for_with_root_path_ending_with_slash(test_client_factory: TestClientFactory) -> None:
    def homepage(request: Request) -> JSONResponse:
        return JSONResponse({"index": str(request.url_for("homepage"))})

    app = Starlette(routes=[Route("/", homepage, name="homepage")])
    client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path/")
    response = client.get("/sub_path/")
    assert response.json() == {"index": "https://www.example.org/sub_path/"}
def test_standalone_route_matches(
    test_client_factory: TestClientFactory,
) -> None:
    app = Route("/", PlainTextResponse("Hello, World!"))
    client = test_client_factory(app)
    response = client.get("/")
    assert response.status_code == 200
    assert response.text == "Hello, World!"


def test_standalone_route_does_not_match(
    test_client_factory: Callable[..., TestClient],
) -> None:
    app = Route("/", PlainTextResponse("Hello, World!"))
    client = test_client_factory(app)
    response = client.get("/invalid")
    assert response.status_code == 404
    assert response.text == "Not Found"


async def ws_helloworld(websocket: WebSocket) -> None:
    await websocket.accept()
    await websocket.send_text("Hello, world!")
    await websocket.close()


def test_standalone_ws_route_matches(
    test_client_factory: TestClientFactory,
) -> None:
    app = WebSocketRoute("/", ws_helloworld)
    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        text = websocket.receive_text()
        assert text == "Hello, world!"


def test_standalone_ws_route_does_not_match(
    test_client_factory: TestClientFactory,
) -> None:
    app = WebSocketRoute("/", ws_helloworld)
    client = test_client_factory(app)
    with pytest.raises(WebSocketDisconnect):
        with client.websocket_connect("/invalid"):
            pass  # pragma: no cover
def test_lifespan_state_unsupported(test_client_factory: TestClientFactory) -> None:
@contextlib.asynccontextmanager
async def lifespan(app: ASGIApp) -> AsyncGenerator[dict[str, str], None]:
yield {"foo": "bar"}
app = Router(
lifespan=lifespan,
routes=[Mount("/", PlainTextResponse("hello, world"))],
)
async def no_state_wrapper(scope: Scope, receive: Receive, send: Send) -> None:
del scope["state"]
await app(scope, receive, send)
with pytest.raises(RuntimeError, match='The server does not support "state" in the lifespan scope'):
with test_client_factory(no_state_wrapper):
raise AssertionError("Should not be called") # pragma: no cover
def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None:
startup_complete = False
shutdown_complete = False
class State(TypedDict):
count: int
items: list[int]
async def hello_world(request: Request) -> Response:
# modifications to the state should not leak across requests
assert request.state.count == 0
# modify the state, this should not leak to the lifespan or other requests
request.state.count += 1
# since state.items is a mutable object this modification _will_ leak across
# requests and to the lifespan
request.state.items.append(1)
return PlainTextResponse("hello, world")
@contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[State]:
nonlocal startup_complete, shutdown_complete
startup_complete = True
state = State(count=0, items=[])
yield state
shutdown_complete = True
# modifications made to the state from a request do not leak to the lifespan
assert state["count"] == 0
# unless of course the request mutates a mutable object that is referenced
# via state
assert state["items"] == [1, 1]
app = Router(
lifespan=lifespan,
routes=[Route("/", hello_world)],
)
assert not startup_complete
assert not shutdown_complete
with test_client_factory(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
# Calling it a second time to ensure that the state is preserved.
client.get("/")
assert startup_complete
assert shutdown_complete
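# Illustration only (not Starlette's implementation): a stdlib sketch of why the
# test above sees `count` reset on every request while `items` keeps growing.
# It assumes each request receives a shallow copy of the lifespan state;
# `handle_request` and `lifespan_state` are hypothetical names.
import copy


def handle_request(lifespan_state: dict) -> dict:
    # Shallow copy: top-level keys are independent, nested objects are shared.
    request_state = copy.copy(lifespan_state)
    request_state["count"] += 1  # rebinds the key on the copy only
    request_state["items"].append(1)  # mutates the list shared with the original
    return request_state


lifespan_state = {"count": 0, "items": []}
first = handle_request(lifespan_state)
second = handle_request(lifespan_state)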
def test_raise_on_startup(test_client_factory: TestClientFactory) -> None:
@contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[Never]:
raise RuntimeError()
yield # pragma: no cover
router = Router(lifespan=lifespan)
startup_failed = False
async def app(scope: Scope, receive: Receive, send: Send) -> None:
async def _send(message: Message) -> None:
nonlocal startup_failed
if message["type"] == "lifespan.startup.failed": # pragma: no branch
startup_failed = True
return await send(message)
await router(scope, receive, _send)
with pytest.raises(RuntimeError):
with test_client_factory(app):
pass # pragma: no cover
assert startup_failed
def test_raise_on_shutdown(test_client_factory: TestClientFactory) -> None:
@contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[None]:
yield
raise RuntimeError("Shutdown failed")
app = Router(lifespan=lifespan)
with pytest.raises(RuntimeError, match="Shutdown failed"):
with test_client_factory(app):
pass # pragma: no cover
def test_partial_async_endpoint(test_client_factory: TestClientFactory) -> None:
test_client = test_client_factory(app)
response = test_client.get("/partial")
assert response.status_code == 200
assert response.json() == {"arg": "foo"}
cls_method_response = test_client.get("/partial/cls")
assert cls_method_response.status_code == 200
assert cls_method_response.json() == {"arg": "foo"}
def test_partial_async_ws_endpoint(
test_client_factory: TestClientFactory,
) -> None:
test_client = test_client_factory(app)
with test_client.websocket_connect("/partial/ws") as websocket:
data = websocket.receive_json()
assert data == {"url": "ws://testserver/partial/ws"}
with test_client.websocket_connect("/partial/ws/cls") as websocket:
data = websocket.receive_json()
assert data == {"url": "ws://testserver/partial/ws/cls"}
def test_duplicated_param_names() -> None:
with pytest.raises(
ValueError,
match="Duplicated param name id at path /{id}/{id}",
):
Route("/{id}/{id}", user)
with pytest.raises(
ValueError,
match="Duplicated param names id, name at path /{id}/{name}/{id}/{name}",
):
Route("/{id}/{name}/{id}/{name}", user)
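# Illustration only: a stdlib sketch of how duplicated path parameters such as
# "/{id}/{id}" could be detected. The regex, `find_duplicated_params` and
# `check_path` are assumptions for this example, not Starlette's actual code;
# only the error message format is taken from the test above.
import re

PARAM_REGEX = re.compile(r"{([a-zA-Z_][a-zA-Z0-9_]*)(?::[a-zA-Z_][a-zA-Z0-9_]*)?}")


def find_duplicated_params(path: str) -> list:
    seen = set()
    duplicated = set()
    for match in PARAM_REGEX.finditer(path):
        name = match.group(1)
        if name in seen:
            duplicated.add(name)
        seen.add(name)
    return sorted(duplicated)


def check_path(path: str) -> None:
    duplicated = find_duplicated_params(path)
    if duplicated:
        noun = "name" if len(duplicated) == 1 else "names"
        raise ValueError(f"Duplicated param {noun} {', '.join(duplicated)} at path {path}")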
class Endpoint:
async def my_method(self, request: Request) -> None: ... # pragma: no cover
@classmethod
async def my_classmethod(cls, request: Request) -> None: ... # pragma: no cover
@staticmethod
async def my_staticmethod(request: Request) -> None: ... # pragma: no cover
def __call__(self, request: Request) -> None: ... # pragma: no cover
@pytest.mark.parametrize(
"endpoint, expected_name",
[
pytest.param(func_homepage, "func_homepage", id="function"),
pytest.param(Endpoint().my_method, "my_method", id="method"),
pytest.param(Endpoint.my_classmethod, "my_classmethod", id="classmethod"),
pytest.param(
Endpoint.my_staticmethod,
"my_staticmethod",
id="staticmethod",
),
pytest.param(Endpoint(), "Endpoint", id="object"),
pytest.param(lambda request: ..., "<lambda>", id="lambda"),  # pragma: no branch
],
)
def test_route_name(endpoint: Callable[..., Response], expected_name: str) -> None:
assert Route(path="/", endpoint=endpoint).name == expected_name
class AddHeadersMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["add_headers_middleware"] = True
async def modified_send(msg: Message) -> None:
if msg["type"] == "http.response.start":
msg["headers"].append((b"X-Test", b"Set by middleware"))
await send(msg)
await self.app(scope, receive, modified_send)
def assert_middleware_header_route(request: Request) -> Response:
assert request.scope["add_headers_middleware"] is True
return Response()
route_with_middleware = Starlette(
routes=[
Route(
"/http",
endpoint=assert_middleware_header_route,
methods=["GET"],
middleware=[Middleware(AddHeadersMiddleware)],
),
Route("/home", homepage),
]
)
mounted_routes_with_middleware = Starlette(
routes=[
Mount(
"/http",
routes=[
Route(
"/",
endpoint=assert_middleware_header_route,
methods=["GET"],
name="route",
),
],
middleware=[Middleware(AddHeadersMiddleware)],
),
Route("/home", homepage),
]
)
mounted_app_with_middleware = Starlette(
routes=[
Mount(
"/http",
app=Route(
"/",
endpoint=assert_middleware_header_route,
methods=["GET"],
name="route",
),
middleware=[Middleware(AddHeadersMiddleware)],
),
Route("/home", homepage),
]
)
@pytest.mark.parametrize(
"app",
[
mounted_routes_with_middleware,
mounted_app_with_middleware,
route_with_middleware,
],
)
def test_base_route_middleware(
test_client_factory: TestClientFactory,
app: Starlette,
) -> None:
test_client = test_client_factory(app)
response = test_client.get("/home")
assert response.status_code == 200
assert "X-Test" not in response.headers
response = test_client.get("/http")
assert response.status_code == 200
assert response.headers["X-Test"] == "Set by middleware"
def test_mount_routes_with_middleware_url_path_for() -> None:
"""Checks that url_path_for still works with mounted routes with Middleware"""
assert mounted_routes_with_middleware.url_path_for("route") == "/http/"
def test_mount_asgi_app_with_middleware_url_path_for() -> None:
"""Mounted ASGI apps do not work with url path for,
middleware does not change this
"""
with pytest.raises(NoMatchFound):
mounted_app_with_middleware.url_path_for("route")
def test_add_route_to_app_after_mount(
test_client_factory: Callable[..., TestClient],
) -> None:
"""Checks that Mount will pick up routes
added to the underlying app after it is mounted
"""
inner_app = Router()
app = Mount("/http", app=inner_app)
inner_app.add_route(
"/inner",
endpoint=homepage,
methods=["GET"],
)
client = test_client_factory(app)
response = client.get("/http/inner")
assert response.status_code == 200
def test_exception_on_mounted_apps(
test_client_factory: TestClientFactory,
) -> None:
def exc(request: Request) -> None:
raise Exception("Exc")
sub_app = Starlette(routes=[Route("/", exc)])
app = Starlette(routes=[Mount("/sub", app=sub_app)])
client = test_client_factory(app)
with pytest.raises(Exception) as ctx:
client.get("/sub/")
assert str(ctx.value) == "Exc"
def test_mounted_middleware_does_not_catch_exception(
test_client_factory: Callable[..., TestClient],
) -> None:
# https://github.com/Kludex/starlette/pull/1649#discussion_r960236107
def exc(request: Request) -> Response:
raise HTTPException(status_code=403, detail="auth")
class NamedMiddleware:
def __init__(self, app: ASGIApp, name: str) -> None:
self.app = app
self.name = name
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async def modified_send(msg: Message) -> None:
if msg["type"] == "http.response.start":
msg["headers"].append((f"X-{self.name}".encode(), b"true"))
await send(msg)
await self.app(scope, receive, modified_send)
app = Starlette(
routes=[
Mount(
"/mount",
routes=[
Route("/err", exc),
Route("/home", homepage),
],
middleware=[Middleware(NamedMiddleware, name="Mounted")],
),
Route("/err", exc),
Route("/home", homepage),
],
middleware=[Middleware(NamedMiddleware, name="Outer")],
)
client = test_client_factory(app)
resp = client.get("/home")
assert resp.status_code == 200, resp.content
assert "X-Outer" in resp.headers
resp = client.get("/err")
assert resp.status_code == 403, resp.content
assert "X-Outer" in resp.headers
resp = client.get("/mount/home")
assert resp.status_code == 200, resp.content
assert "X-Mounted" in resp.headers
resp = client.get("/mount/err")
assert resp.status_code == 403, resp.content
assert "X-Mounted" in resp.headers
def test_websocket_route_middleware(
test_client_factory: TestClientFactory,
) -> None:
async def websocket_endpoint(session: WebSocket) -> None:
await session.accept()
await session.send_text("Hello, world!")
await session.close()
class WebsocketMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async def modified_send(msg: Message) -> None:
if msg["type"] == "websocket.accept":
msg["headers"].append((b"X-Test", b"Set by middleware"))
await send(msg)
await self.app(scope, receive, modified_send)
app = Starlette(
routes=[
WebSocketRoute(
"/ws",
endpoint=websocket_endpoint,
middleware=[Middleware(WebsocketMiddleware)],
)
]
)
client = test_client_factory(app)
with client.websocket_connect("/ws") as websocket:
text = websocket.receive_text()
assert text == "Hello, world!"
assert websocket.extra_headers == [(b"X-Test", b"Set by middleware")]
def test_route_repr() -> None:
route = Route("/welcome", endpoint=homepage)
assert repr(route) == "Route(path='/welcome', name='homepage', methods=['GET', 'HEAD'])"
def test_route_repr_without_methods() -> None:
route = Route("/welcome", endpoint=Endpoint, methods=None)
assert repr(route) == "Route(path='/welcome', name='Endpoint', methods=[])"
def test_websocket_route_repr() -> None:
route = WebSocketRoute("/ws", endpoint=websocket_endpoint)
assert repr(route) == "WebSocketRoute(path='/ws', name='websocket_endpoint')"
def test_mount_repr() -> None:
route = Mount(
"/app",
routes=[
Route("/", endpoint=homepage),
],
)
# test for substring because repr(Router) returns unique object ID
assert repr(route).startswith("Mount(path='/app', name='', app=")
def test_mount_named_repr() -> None:
route = Mount(
"/app",
name="app",
routes=[
Route("/", endpoint=homepage),
],
)
# test for substring because repr(Router) returns unique object ID
assert repr(route).startswith("Mount(path='/app', name='app', app=")
def test_host_repr() -> None:
route = Host(
"example.com",
app=Router(
[
Route("/", endpoint=homepage),
]
),
)
# test for substring because repr(Router) returns unique object ID
assert repr(route).startswith("Host(host='example.com', name='', app=")
def test_host_named_repr() -> None:
route = Host(
"example.com",
name="app",
app=Router(
[
Route("/", endpoint=homepage),
]
),
)
# test for substring because repr(Router) returns unique object ID
assert repr(route).startswith("Host(host='example.com', name='app', app=")
async def echo_paths(request: Request, name: str) -> JSONResponse:
return JSONResponse(
{
"name": name,
"path": request.scope["path"],
"root_path": request.scope["root_path"],
}
)
async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str) -> None:
data = {"name": name, "path": scope["path"], "root_path": scope["root_path"]}
content = json.dumps(data).encode("utf-8")
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-type", b"application/json")],
}
)
await send({"type": "http.response.body", "body": content})
echo_paths_routes = [
Route(
"/path",
functools.partial(echo_paths, name="path"),
name="path",
methods=["GET"],
),
Route(
"/root-queue/path",
functools.partial(echo_paths, name="queue_path"),
name="queue_path",
methods=["POST"],
),
Mount("/asgipath", app=functools.partial(pure_asgi_echo_paths, name="asgipath")),
Mount(
"/sub",
name="mount",
routes=[
Route(
"/path",
functools.partial(echo_paths, name="subpath"),
name="subpath",
methods=["GET"],
),
],
),
]
def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None:
app = Starlette(routes=echo_paths_routes)
client = test_client_factory(app, base_url="https://www.example.org/", root_path="/root")
response = client.get("/root/path")
assert response.status_code == 200
assert response.json() == {
"name": "path",
"path": "/root/path",
"root_path": "/root",
}
response = client.get("/root/asgipath/")
assert response.status_code == 200
assert response.json() == {
"name": "asgipath",
"path": "/root/asgipath/",
# Things that mount other ASGI apps, like WSGIMiddleware, would not be aware
# of the prefixed path, and would have their own notion of their own paths,
# so they need to be able to rely on the root_path to know the location they
# are mounted on
"root_path": "/root/asgipath",
}
response = client.get("/root/sub/path")
assert response.status_code == 200
assert response.json() == {
"name": "subpath",
"path": "/root/sub/path",
"root_path": "/root/sub",
}
response = client.post("/root/root-queue/path")
assert response.status_code == 200
assert response.json() == {
"name": "queue_path",
"path": "/root/root-queue/path",
"root_path": "/root",
}
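# Illustration only (hypothetical helper, not Starlette's API): a sketch of the
# path bookkeeping asserted above. It assumes a mount leaves "path" untouched
# and extends "root_path" by the prefix it matched, so the mounted app can work
# out where it is mounted.
def child_scope_paths(path: str, root_path: str, mount_prefix: str) -> dict:
    child_root_path = root_path + mount_prefix
    # The part of the path that the mounted app still has to route.
    remaining = path[len(child_root_path):] if path.startswith(child_root_path) else path
    return {"path": path, "root_path": child_root_path, "remaining": remaining}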
================================================
FILE: tests/test_schemas.py
================================================
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.schemas import SchemaGenerator
from starlette.websockets import WebSocket
from tests.types import TestClientFactory
schemas = SchemaGenerator({"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}})
def ws(session: WebSocket) -> None:
"""ws"""
pass # pragma: no cover
def get_user(request: Request) -> None:
"""
responses:
200:
description: A user.
examples:
{"username": "tom"}
"""
pass # pragma: no cover
def list_users(request: Request) -> None:
"""
responses:
200:
description: A list of users.
examples:
[{"username": "tom"}, {"username": "lucy"}]
"""
pass # pragma: no cover
def create_user(request: Request) -> None:
"""
responses:
200:
description: A user.
examples:
{"username": "tom"}
"""
pass # pragma: no cover
class OrganisationsEndpoint(HTTPEndpoint):
def get(self, request: Request) -> None:
"""
responses:
200:
description: A list of organisations.
examples:
[{"name": "Foo Corp."}, {"name": "Acme Ltd."}]
"""
pass # pragma: no cover
def post(self, request: Request) -> None:
"""
responses:
200:
description: An organisation.
examples:
{"name": "Foo Corp."}
"""
pass # pragma: no cover
def regular_docstring_and_schema(request: Request) -> None:
"""
This is a regular docstring example (not included in schema)
---
responses:
200:
description: This is included in the schema.
"""
pass # pragma: no cover
def regular_docstring(request: Request) -> None:
"""
This is a regular docstring example (not included in schema)
"""
pass # pragma: no cover
def no_docstring(request: Request) -> None:
pass # pragma: no cover
def subapp_endpoint(request: Request) -> None:
"""
responses:
200:
description: This endpoint is part of a subapp.
"""
pass # pragma: no cover
def schema(request: Request) -> Response:
return schemas.OpenAPIResponse(request=request)
subapp = Starlette(
routes=[
Route("/subapp-endpoint", endpoint=subapp_endpoint),
]
)
app = Starlette(
routes=[
WebSocketRoute("/ws", endpoint=ws),
Route("/users/{id:int}", endpoint=get_user, methods=["GET"]),
Route("/users", endpoint=list_users, methods=["GET", "HEAD"]),
Route("/users", endpoint=create_user, methods=["POST"]),
Route("/orgs", endpoint=OrganisationsEndpoint),
Route("/regular-docstring-and-schema", endpoint=regular_docstring_and_schema),
Route("/regular-docstring", endpoint=regular_docstring),
Route("/no-docstring", endpoint=no_docstring),
Route("/schema", endpoint=schema, methods=["GET"], include_in_schema=False),
Mount("/subapp", subapp),
Host("sub.domain.com", app=Router(routes=[Mount("/subapp2", subapp)])),
]
)
def test_schema_generation() -> None:
schema = schemas.get_schema(routes=app.routes)
assert schema == {
"openapi": "3.0.0",
"info": {"title": "Example API", "version": "1.0"},
"paths": {
"/orgs": {
"get": {
"responses": {
200: {
"description": "A list of organisations.",
"examples": [{"name": "Foo Corp."}, {"name": "Acme Ltd."}],
}
}
},
"post": {
"responses": {
200: {
"description": "An organisation.",
"examples": {"name": "Foo Corp."},
}
}
},
},
"/regular-docstring-and-schema": {
"get": {"responses": {200: {"description": "This is included in the schema."}}}
},
"/subapp/subapp-endpoint": {
"get": {"responses": {200: {"description": "This endpoint is part of a subapp."}}}
},
"/subapp2/subapp-endpoint": {
"get": {"responses": {200: {"description": "This endpoint is part of a subapp."}}}
},
"/users": {
"get": {
"responses": {
200: {
"description": "A list of users.",
"examples": [{"username": "tom"}, {"username": "lucy"}],
}
}
},
"post": {"responses": {200: {"description": "A user.", "examples": {"username": "tom"}}}},
},
"/users/{id}": {
"get": {
"responses": {
200: {
"description": "A user.",
"examples": {"username": "tom"},
}
}
},
},
},
}
EXPECTED_SCHEMA = """
info:
title: Example API
version: '1.0'
openapi: 3.0.0
paths:
/orgs:
get:
responses:
200:
description: A list of organisations.
examples:
- name: Foo Corp.
- name: Acme Ltd.
post:
responses:
200:
description: An organisation.
examples:
name: Foo Corp.
/regular-docstring-and-schema:
get:
responses:
200:
description: This is included in the schema.
/subapp/subapp-endpoint:
get:
responses:
200:
description: This endpoint is part of a subapp.
/subapp2/subapp-endpoint:
get:
responses:
200:
description: This endpoint is part of a subapp.
/users:
get:
responses:
200:
description: A list of users.
examples:
- username: tom
- username: lucy
post:
responses:
200:
description: A user.
examples:
username: tom
/users/{id}:
get:
responses:
200:
description: A user.
examples:
username: tom
"""
def test_schema_endpoint(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.get("/schema")
assert response.headers["Content-Type"] == "application/vnd.oai.openapi"
assert response.text.strip() == EXPECTED_SCHEMA.strip()
================================================
FILE: tests/test_staticfiles.py
================================================
import os
import stat
import tempfile
import time
from pathlib import Path
from typing import Any
import anyio
import pytest
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount
from starlette.staticfiles import StaticFiles
from tests.types import TestClientFactory
def test_staticfiles(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
app = StaticFiles(directory=tmpdir)
client = test_client_factory(app)
response = client.get("/example.txt")
assert response.status_code == 200
assert response.text == "<file content>"
def test_staticfiles_with_pathlib(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "example.txt"
with open(path, "w") as file:
file.write("<file content>")
app = StaticFiles(directory=tmp_path)
client = test_client_factory(app)
response = client.get("/example.txt")
assert response.status_code == 200
assert response.text == "<file content>"
def test_staticfiles_head_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
"""
see https://github.com/Kludex/starlette/pull/935
"""
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("x" * 100)
async def does_nothing_middleware(request: Request, call_next: RequestResponseEndpoint) -> Response:
response = await call_next(request)
return response
routes = [Mount("/static", app=StaticFiles(directory=tmpdir), name="static")]
middleware = [Middleware(BaseHTTPMiddleware, dispatch=does_nothing_middleware)]
app = Starlette(routes=routes, middleware=middleware)
client = test_client_factory(app)
response = client.head("/static/example.txt")
assert response.status_code == 200
assert response.headers.get("content-length") == "100"
def test_staticfiles_with_package(test_client_factory: TestClientFactory) -> None:
app = StaticFiles(packages=["tests"])
client = test_client_factory(app)
response = client.get("/example.txt")
assert response.status_code == 200
assert response.text == "123\n"
app = StaticFiles(packages=[("tests", "statics")])
client = test_client_factory(app)
response = client.get("/example.txt")
assert response.status_code == 200
assert response.text == "123\n"
def test_staticfiles_post(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
app = Starlette(routes=routes)
client = test_client_factory(app)
response = client.post("/example.txt")
assert response.status_code == 405
assert response.text == "Method Not Allowed"
def test_staticfiles_with_directory_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
app = Starlette(routes=routes)
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 404
assert response.text == "Not Found"
def test_staticfiles_with_missing_file_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
app = Starlette(routes=routes)
client = test_client_factory(app)
response = client.get("/404.txt")
assert response.status_code == 404
assert response.text == "Not Found"
def test_staticfiles_instantiated_with_missing_directory(tmpdir: Path) -> None:
with pytest.raises(RuntimeError) as exc_info:
path = os.path.join(tmpdir, "no_such_directory")
StaticFiles(directory=path)
assert "does not exist" in str(exc_info.value)
def test_staticfiles_configured_with_missing_directory(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "no_such_directory")
app = StaticFiles(directory=path, check_dir=False)
client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
client.get("/example.txt")
assert "does not exist" in str(exc_info.value)
def test_staticfiles_configured_with_file_instead_of_directory(
tmpdir: Path, test_client_factory: TestClientFactory
) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
app = StaticFiles(directory=path, check_dir=False)
client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
client.get("/example.txt")
assert "is not a directory" in str(exc_info.value)
def test_staticfiles_config_check_occurs_only_once(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
app = StaticFiles(directory=tmpdir)
client = test_client_factory(app)
assert not app.config_checked
with pytest.raises(HTTPException):
client.get("/")
assert app.config_checked
with pytest.raises(HTTPException):
client.get("/")
def test_staticfiles_prevents_breaking_out_of_directory(tmpdir: Path) -> None:
directory = os.path.join(tmpdir, "foo")
os.mkdir(directory)
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("outside root dir")
app = StaticFiles(directory=directory)
# We can't test this with 'httpx', so we test the app directly here.
path = app.get_path({"path": "/../example.txt"})
scope = {"method": "GET"}
with pytest.raises(HTTPException) as exc_info:
anyio.run(app.get_response, path, scope)
assert exc_info.value.status_code == 404
assert exc_info.value.detail == "Not Found"
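# Illustration only: a stdlib sketch of the containment check exercised above,
# where "/../example.txt" must not resolve to a file outside the served
# directory. `is_within_directory` is a hypothetical helper, not Starlette's API.
import os


def is_within_directory(directory: str, relative_path: str) -> bool:
    root = os.path.realpath(directory)
    candidate = os.path.realpath(os.path.join(root, relative_path.lstrip("/")))
    # commonpath equals `root` only when `candidate` is `root` itself or below it.
    return os.path.commonpath([root, candidate]) == root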
def test_staticfiles_never_read_file_for_head_method(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
app = StaticFiles(directory=tmpdir)
client = test_client_factory(app)
response = client.head("/example.txt")
assert response.status_code == 200
assert response.content == b""
assert response.headers["content-length"] == "14"
def test_staticfiles_304_with_etag_match(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
app = StaticFiles(directory=tmpdir)
client = test_client_factory(app)
first_resp = client.get("/example.txt")
assert first_resp.status_code == 200
last_etag = first_resp.headers["etag"]
second_resp = client.get("/example.txt", headers={"if-none-match": last_etag})
assert second_resp.status_code == 304
assert second_resp.content == b""
second_resp = client.get("/example.txt", headers={"if-none-match": f'W/{last_etag}, "123"'})
assert second_resp.status_code == 304
assert second_resp.content == b""
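# Illustration only: a stdlib sketch of weak ETag comparison for an
# `If-None-Match` header such as 'W/"abc", "123"', as sent in the test above.
# `etag_matches` is a hypothetical helper, not Starlette's implementation, and
# it assumes the entity tags themselves contain no commas.
def etag_matches(if_none_match: str, etag: str) -> bool:
    def strip_weak(tag: str) -> str:
        tag = tag.strip()
        return tag[2:] if tag.startswith("W/") else tag

    candidates = {strip_weak(tag) for tag in if_none_match.split(",")}
    # "*" matches any current representation.
    return "*" in candidates or strip_weak(etag) in candidates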
def test_staticfiles_200_with_etag_mismatch(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
app = StaticFiles(directory=tmpdir)
client = test_client_factory(app)
first_resp = client.get("/example.txt")
assert first_resp.status_code == 200
assert first_resp.headers["etag"] != '"123"'
second_resp = client.get("/example.txt", headers={"if-none-match": '"123"'})
assert second_resp.status_code == 200
assert second_resp.content == b"<file content>"
def test_staticfiles_200_with_etag_mismatch_and_timestamp_match(
tmpdir: Path, test_client_factory: TestClientFactory
) -> None:
path = tmpdir / "example.txt"
path.write_text("<file content>", encoding="utf-8")
app = StaticFiles(directory=tmpdir)
client = test_client_factory(app)
first_resp = client.get("/example.txt")
assert first_resp.status_code == 200
assert first_resp.headers["etag"] != '"123"'
last_modified = first_resp.headers["last-modified"]
# If `if-none-match` is present, `if-modified-since` is ignored.
second_resp = client.get("/example.txt", headers={"if-none-match": '"123"', "if-modified-since": last_modified})
assert second_resp.status_code == 200
assert second_resp.content == b"<file content>"
def test_staticfiles_304_with_last_modified_compare_last_req(
tmpdir: Path, test_client_factory: TestClientFactory
) -> None:
path = os.path.join(tmpdir, "example.txt")
file_last_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S"))
with open(path, "w") as file:
file.write("<file content>")
os.utime(path, (file_last_modified_time, file_last_modified_time))
app = StaticFiles(directory=tmpdir)
client = test_client_factory(app)
# file last modified before the If-Modified-Since date: 304 with no body
response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"})
assert response.status_code == 304
assert response.content == b""
# file last modified after the If-Modified-Since date: 200 with content
response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 20 Feb 2012 15:30:19 GMT"})
assert response.status_code == 200
assert response.content == b"<file content>"
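# Illustration only: a stdlib sketch of the `If-Modified-Since` comparison the
# test above relies on. `not_modified_since` is a hypothetical helper, not
# Starlette's implementation; it assumes both values are HTTP dates and treats
# anything it cannot parse or compare as "modified".
from email.utils import parsedate_to_datetime


def not_modified_since(if_modified_since: str, last_modified: str) -> bool:
    try:
        since = parsedate_to_datetime(if_modified_since)
        modified = parsedate_to_datetime(last_modified)
        # True means the client's copy is still fresh, so a 304 would be sent.
        return modified <= since
    except (TypeError, ValueError):
        return False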
def test_staticfiles_html_normal(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "404.html")
with open(path, "w") as file:
file.write("<h1>Custom not found page</h1>")
path = os.path.join(tmpdir, "dir")
os.mkdir(path)
path = os.path.join(path, "index.html")
with open(path, "w") as file:
file.write("