Repository: aio-libs/aiohttp Branch: master Commit: 2602b711710d Files: 437 Total size: 3.9 MB Directory structure: gitextract_ex8k9ysz/ ├── .cherry_picker.toml ├── .codecov.yml ├── .coveragerc.toml ├── .editorconfig ├── .git-blame-ignore-revs ├── .gitattributes ├── .github/ │ ├── CODEOWNERS │ ├── FUNDING.yml │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.yml │ │ ├── config.yml │ │ └── feature_request.yml │ ├── ISSUE_TEMPLATE.md │ ├── PULL_REQUEST_TEMPLATE.md │ ├── codeql.yml │ ├── config.yml │ ├── dependabot.yml │ ├── lock.yml │ └── workflows/ │ ├── auto-merge.yml │ ├── ci-cd.yml │ ├── codeql.yml │ ├── label-remove.yml │ ├── labels.yml │ └── stale.yml ├── .gitignore ├── .gitmodules ├── .lgtm.yml ├── .mypy.ini ├── .pip-tools.toml ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── CHANGES/ │ ├── .TEMPLATE.rst │ ├── .gitignore │ ├── 10468.doc.rst │ ├── 10596.bugfix.rst │ ├── 10611.bugfix.rst │ ├── 10665.feature.rst │ ├── 10683.bugfix.rst │ ├── 10753.bugfix.rst │ ├── 10795.doc.rst │ ├── 11012.breaking.rst │ ├── 11268.feature.rst │ ├── 11283.bugfix.rst │ ├── 11601.breaking.rst │ ├── 11681.feature.rst │ ├── 11737.contrib.rst │ ├── 11763.feature.rst │ ├── 11766.feature.rst │ ├── 11776.misc.rst │ ├── 11826.contrib.rst │ ├── 11859.bugfix.rst │ ├── 11876.misc.rst │ ├── 11898.bugfix.rst │ ├── 11937.misc.rst │ ├── 11955.feature.rst │ ├── 11972.bugfix.rst │ ├── 11989.feature.rst │ ├── 11992.contrib.rst │ ├── 12027.misc.rst │ ├── 12030.bugfix.rst │ ├── 12042.doc.rst │ ├── 12069.packaging.rst │ ├── 12088.bugfix.rst │ ├── 12091.bugfix.rst │ ├── 12096.bugfix.rst │ ├── 12097.bugfix.rst │ ├── 12106.feature.rst │ ├── 12136.bugfix.rst │ ├── 12170.misc.rst │ ├── 12173.contrib.rst │ ├── 12195.bugfix.rst │ ├── 12231.bugfix.rst │ ├── 12240.bugfix.rst │ ├── 12249.bugfix.rst │ ├── 2174.bugfix │ ├── 2835.breaking.rst │ ├── 2977.breaking.rst │ ├── 3310.bugfix │ ├── 3462.feature │ ├── 3463.breaking.rst │ ├── 3482.bugfix │ ├── 3538.breaking.rst │ ├── 3539.breaking.rst │ ├── 3540.feature │ ├── 
3542.breaking.rst │ ├── 3545.feature │ ├── 3547.breaking.rst │ ├── 3548.breaking.rst │ ├── 3559.doc │ ├── 3562.bugfix │ ├── 3569.feature │ ├── 3580.breaking.rst │ ├── 3612.bugfix │ ├── 3613.bugfix │ ├── 3642.doc │ ├── 3685.doc │ ├── 3721.bugfix │ ├── 3767.feature │ ├── 3787.feature │ ├── 3796.feature │ ├── 3890.breaking.rst │ ├── 3901.breaking.rst │ ├── 3929.breaking.rst │ ├── 3931.breaking.rst │ ├── 3932.breaking.rst │ ├── 3933.breaking.rst │ ├── 3934.breaking.rst │ ├── 3935.breaking.rst │ ├── 3939.breaking.rst │ ├── 3940.breaking.rst │ ├── 3942.breaking.rst │ ├── 3948.breaking.rst │ ├── 3994.misc │ ├── 4161.doc │ ├── 4277.feature │ ├── 4283.bugfix │ ├── 4299.bugfix │ ├── 4302.bugfix │ ├── 4368.bugfix │ ├── 4452.doc │ ├── 4504.doc │ ├── 4526.bugfix │ ├── 4558.bugfix │ ├── 4656.bugfix │ ├── 4695.doc │ ├── 4706.feature │ ├── 5075.feature │ ├── 5191.doc │ ├── 5258.bugfix │ ├── 5278.breaking.rst │ ├── 5284.breaking.rst │ ├── 5284.feature │ ├── 5287.feature │ ├── 5516.misc │ ├── 5533.misc │ ├── 5558.bugfix │ ├── 5634.feature │ ├── 5783.feature │ ├── 5806.misc │ ├── 5829.misc │ ├── 5870.misc │ ├── 5894.bugfix │ ├── 6180.bugfix │ ├── 6181.bugfix │ ├── 6193.feature │ ├── 6547.bugfix │ ├── 6721.misc │ ├── 6979.doc │ ├── 6998.doc │ ├── 7107.breaking.rst │ ├── 7265.breaking.rst │ ├── 7319.feature.rst │ ├── 7677.bugfix │ ├── 7772.bugfix │ ├── 7815.bugfix │ ├── 8048.breaking.rst │ ├── 8139.contrib.rst │ ├── 8197.doc │ ├── 8303.breaking.rst │ ├── 8596.breaking.rst │ ├── 8698.breaking.rst │ ├── 8957.breaking.rst │ ├── 9109.breaking.rst │ ├── 9212.packaging.rst │ ├── 9254.breaking.rst │ ├── 9292.breaking.rst │ ├── 9413.misc.rst │ └── README.rst ├── CHANGES.rst ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.rst ├── CONTRIBUTORS.txt ├── LICENSE.txt ├── MANIFEST.in ├── Makefile ├── README.rst ├── aiohttp/ │ ├── __init__.py │ ├── _cookie_helpers.py │ ├── _cparser.pxd │ ├── _find_header.h │ ├── _find_header.pxd │ ├── _http_parser.pyx │ ├── _http_writer.pyx │ ├── _websocket/ │ │ ├── 
__init__.py │ │ ├── helpers.py │ │ ├── mask.pxd │ │ ├── mask.pyx │ │ ├── models.py │ │ ├── reader.py │ │ ├── reader_c.pxd │ │ ├── reader_py.py │ │ └── writer.py │ ├── abc.py │ ├── base_protocol.py │ ├── client.py │ ├── client_exceptions.py │ ├── client_middleware_digest_auth.py │ ├── client_middlewares.py │ ├── client_proto.py │ ├── client_reqrep.py │ ├── client_ws.py │ ├── compression_utils.py │ ├── connector.py │ ├── cookiejar.py │ ├── formdata.py │ ├── hdrs.py │ ├── helpers.py │ ├── http.py │ ├── http_exceptions.py │ ├── http_parser.py │ ├── http_websocket.py │ ├── http_writer.py │ ├── log.py │ ├── multipart.py │ ├── payload.py │ ├── py.typed │ ├── pytest_plugin.py │ ├── resolver.py │ ├── streams.py │ ├── tcp_helpers.py │ ├── test_utils.py │ ├── tracing.py │ ├── typedefs.py │ ├── web.py │ ├── web_app.py │ ├── web_exceptions.py │ ├── web_fileresponse.py │ ├── web_log.py │ ├── web_middlewares.py │ ├── web_protocol.py │ ├── web_request.py │ ├── web_response.py │ ├── web_routedef.py │ ├── web_runner.py │ ├── web_server.py │ ├── web_urldispatcher.py │ ├── web_ws.py │ └── worker.py ├── docs/ │ ├── Makefile │ ├── _static/ │ │ └── css/ │ │ └── logo-adjustments.css │ ├── abc.rst │ ├── built_with.rst │ ├── changes.rst │ ├── client.rst │ ├── client_advanced.rst │ ├── client_middleware_cookbook.rst │ ├── client_quickstart.rst │ ├── client_reference.rst │ ├── code/ │ │ └── client_middleware_cookbook.py │ ├── conf.py │ ├── contributing-admins.rst │ ├── contributing.rst │ ├── deployment.rst │ ├── essays.rst │ ├── external.rst │ ├── faq.rst │ ├── glossary.rst │ ├── http_request_lifecycle.rst │ ├── index.rst │ ├── logging.rst │ ├── make.bat │ ├── migration_to_2xx.rst │ ├── misc.rst │ ├── multipart.rst │ ├── multipart_reference.rst │ ├── new_router.rst │ ├── powered_by.rst │ ├── spelling_wordlist.txt │ ├── streams.rst │ ├── structures.rst │ ├── testing.rst │ ├── third_party.rst │ ├── tracing_reference.rst │ ├── utilities.rst │ ├── web.rst │ ├── web_advanced.rst │ ├── 
web_exceptions.rst │ ├── web_lowlevel.rst │ ├── web_quickstart.rst │ ├── web_reference.rst │ ├── websocket_utilities.rst │ ├── whats_new_1_1.rst │ └── whats_new_3_0.rst ├── examples/ │ ├── background_tasks.py │ ├── basic_auth_middleware.py │ ├── cli_app.py │ ├── client_auth.py │ ├── client_json.py │ ├── client_ws.py │ ├── combined_middleware.py │ ├── curl.py │ ├── digest_auth_qop_auth.py │ ├── fake_server.py │ ├── logging_middleware.py │ ├── lowlevel_srv.py │ ├── retry_middleware.py │ ├── server.crt │ ├── server.csr │ ├── server.key │ ├── server_simple.py │ ├── static_files.py │ ├── token_refresh_middleware.py │ ├── web_classview.py │ ├── web_cookies.py │ ├── web_rewrite_headers_middleware.py │ ├── web_srv.py │ ├── web_srv_route_deco.py │ ├── web_srv_route_table.py │ ├── web_ws.py │ └── websocket.html ├── pyproject.toml ├── requirements/ │ ├── base-ft.in │ ├── base-ft.txt │ ├── base.in │ ├── base.txt │ ├── constraints.in │ ├── constraints.txt │ ├── cython.in │ ├── cython.txt │ ├── dev.in │ ├── dev.txt │ ├── doc-spelling.in │ ├── doc-spelling.txt │ ├── doc.in │ ├── doc.txt │ ├── lint.in │ ├── lint.txt │ ├── multidict.in │ ├── multidict.txt │ ├── runtime-deps.in │ ├── runtime-deps.txt │ ├── sync-direct-runtime-deps.py │ ├── test-common.in │ ├── test-common.txt │ ├── test-ft.in │ ├── test-ft.txt │ ├── test.in │ └── test.txt ├── setup.cfg ├── setup.py ├── tests/ │ ├── autobahn/ │ │ ├── Dockerfile.aiohttp │ │ ├── Dockerfile.autobahn │ │ ├── client/ │ │ │ ├── client.py │ │ │ └── fuzzingserver.json │ │ ├── server/ │ │ │ ├── fuzzingclient.json │ │ │ └── server.py │ │ └── test_autobahn.py │ ├── conftest.py │ ├── data.unknown_mime_type │ ├── data.zero_bytes │ ├── github-urls.json │ ├── isolated/ │ │ ├── check_for_client_response_leak.py │ │ └── check_for_request_leak.py │ ├── sample.txt │ ├── test_base_protocol.py │ ├── test_benchmarks_client.py │ ├── test_benchmarks_client_request.py │ ├── test_benchmarks_client_ws.py │ ├── test_benchmarks_cookiejar.py │ ├── 
test_benchmarks_http_websocket.py │ ├── test_benchmarks_http_writer.py │ ├── test_benchmarks_web_fileresponse.py │ ├── test_benchmarks_web_middleware.py │ ├── test_benchmarks_web_response.py │ ├── test_benchmarks_web_urldispatcher.py │ ├── test_circular_imports.py │ ├── test_classbasedview.py │ ├── test_client_connection.py │ ├── test_client_exceptions.py │ ├── test_client_fingerprint.py │ ├── test_client_functional.py │ ├── test_client_middleware.py │ ├── test_client_middleware_digest_auth.py │ ├── test_client_proto.py │ ├── test_client_request.py │ ├── test_client_response.py │ ├── test_client_session.py │ ├── test_client_ws.py │ ├── test_client_ws_functional.py │ ├── test_compression_utils.py │ ├── test_connector.py │ ├── test_cookie_helpers.py │ ├── test_cookiejar.py │ ├── test_flowcontrol_streams.py │ ├── test_formdata.py │ ├── test_helpers.py │ ├── test_http_exceptions.py │ ├── test_http_parser.py │ ├── test_http_writer.py │ ├── test_imports.py │ ├── test_leaks.py │ ├── test_loop.py │ ├── test_multipart.py │ ├── test_multipart_helpers.py │ ├── test_payload.py │ ├── test_proxy.py │ ├── test_proxy_functional.py │ ├── test_pytest_plugin.py │ ├── test_resolver.py │ ├── test_route_def.py │ ├── test_run_app.py │ ├── test_streams.py │ ├── test_tcp_helpers.py │ ├── test_test_utils.py │ ├── test_tracing.py │ ├── test_urldispatch.py │ ├── test_web_app.py │ ├── test_web_cli.py │ ├── test_web_exceptions.py │ ├── test_web_functional.py │ ├── test_web_log.py │ ├── test_web_middleware.py │ ├── test_web_protocol.py │ ├── test_web_request.py │ ├── test_web_request_handler.py │ ├── test_web_response.py │ ├── test_web_runner.py │ ├── test_web_sendfile.py │ ├── test_web_sendfile_functional.py │ ├── test_web_server.py │ ├── test_web_urldispatcher.py │ ├── test_web_websocket.py │ ├── test_web_websocket_functional.py │ ├── test_websocket_data_queue.py │ ├── test_websocket_handshake.py │ ├── test_websocket_parser.py │ ├── test_websocket_writer.py │ └── test_worker.py ├── tools/ │ 
├── bench-asyncio-write.py │ ├── check_changes.py │ ├── check_sum.py │ ├── cleanup_changes.py │ ├── drop_merged_branches.sh │ ├── gen.py │ └── testing/ │ ├── Dockerfile │ ├── Dockerfile.dockerignore │ └── entrypoint.sh └── vendor/ └── README.rst ================================================ FILE CONTENTS ================================================ ================================================ FILE: .cherry_picker.toml ================================================ team = "aio-libs" repo = "aiohttp" check_sha = "f382b5ffc445e45a110734f5396728da7914aeb6" fix_commit_msg = false ================================================ FILE: .codecov.yml ================================================ codecov: branch: master notify: manual_trigger: true comment: require_head: false require_base: false coverage: range: "95..100" status: project: no component_management: individual_components: - component_id: project paths: - aiohttp/** - component_id: tests paths: - tests/** flags: library: paths: - aiohttp/ configs: paths: - requirements/ - ".git*" - "*.toml" - "*.yml" changelog: paths: - CHANGES/ - CHANGES.rst docs: paths: - docs/ - "*.md" - "*.rst" - "*.txt" tests: paths: - tests/ tools: paths: - tools/ third-party: paths: - vendor/ ================================================ FILE: .coveragerc.toml ================================================ [run] branch = true # NOTE: `ctrace` tracing method is needed because the `sysmon` tracer # NOTE: which is default on Python 3.14, causes unprecedented slow-down # NOTE: of the test runs. 
# Ref: https://github.com/coveragepy/coveragepy/issues/2099 core = 'ctrace' source = [ 'aiohttp', 'tests', ] omit = [ 'site-packages', ] [report] exclude_also = [ 'if TYPE_CHECKING', 'assert False', ': \.\.\.(\s*#.*)?$', '^ +\.\.\.$', 'pytest.fail\(' ] ================================================ FILE: .editorconfig ================================================ # EditorConfig is awesome: http://EditorConfig.org # top-most EditorConfig file root = true # Unix-style newlines with a newline ending every file [*] end_of_line = lf insert_final_newline = true indent_style = space indent_size = 4 trim_trailing_whitespace = true charset = utf-8 [Makefile] indent_style = tab [*.{yml,yaml}] indent_size = 2 [*.rst] max_line_length = 80 ================================================ FILE: .git-blame-ignore-revs ================================================ # git hyper-blame master ignore list. # # This file contains a list of git hashes of revisions to be ignored by git # hyper-blame (in depot_tools). These revisions are considered "unimportant" in # that they are unlikely to be what you are interested in when blaming. # # Instructions: # - Only large (generally automated) reformatting or renaming CLs should be # added to this list. Do not put things here just because you feel they are # trivial or unimportant. If in doubt, do not put it on this list. # - Precede each revision with a comment containing the first line of its log. # For bulk work over many commits, place all commits in a block with a single # comment at the top describing the work done in those commits. # - Only put full 40-character hashes on this list (not short hashes or any # other revision reference). # - Append to the bottom of the file (revisions should be in chronological order # from oldest to newest). # - Because you must use a hash, you need to append to this list in a follow-up # CL to the actual reformatting CL that you are trying to ignore. 
# Black 6ab76b084bf5012b7185046162ed92bedcf073b5 # Apply new hooks 41c5467a62fb1041b77356ea22b81a74305941ef # Tune C source generation 3f7d64798d46b8b166139811a377ba231b4f36bf # Apply pyupgrade 32833c3fe081ff75c0b08ace9cc71e821a72fc5e ================================================ FILE: .gitattributes ================================================ tests/data.unknown_mime_type binary tests/sample.* binary ================================================ FILE: .github/CODEOWNERS ================================================ * @asvetlov /.github/* @webknjaz @asvetlov /.circleci/* @webknjaz @asvetlov /CHANGES/* @asvetlov /docs/* @asvetlov /examples/* @asvetlov /requirements/* @webknjaz @asvetlov /tests/* @asvetlov /tools/* @webknjaz @asvetlov /vendor/* @webknjaz @asvetlov *.ini @webknjaz @asvetlov *.md @webknjaz @asvetlov *.rst @webknjaz @asvetlov *.toml @webknjaz @asvetlov *.txt @webknjaz @asvetlov *.yml @webknjaz @asvetlov *.yaml @webknjaz @asvetlov .editorconfig @webknjaz @asvetlov .git* @webknjaz Makefile @webknjaz @asvetlov setup.py @webknjaz @asvetlov setup.cfg @webknjaz @asvetlov tox.ini @webknjaz ================================================ FILE: .github/FUNDING.yml ================================================ --- github: - asvetlov - webknjaz - Dreamsorcerer open_collective: aiohttp tidelift: pypi/aiohttp # A single Tidelift platform-name/package-name e.g., npm/babel custom: - https://opencollective.com/aio-libs ... ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.yml ================================================ --- name: Bug Report description: Create a report to help us improve. labels: [bug] assignees: aio-libs/triagers body: - type: markdown attributes: value: | **Thanks for taking a minute to file a bug report!** ⚠ Verify first that your issue is not [already reported on GitHub][issue search]. 
_Please fill out the form below with as many precise details as possible._ [issue search]: ../search?q=is%3Aissue&type=issues - type: textarea attributes: label: Describe the bug description: >- A clear and concise description of what the bug is. validations: required: true - type: textarea attributes: label: To Reproduce description: >- Describe the steps to reproduce this bug. placeholder: | 1. Implement the following server or a client '...' 2. Then run '...' 3. An error occurs. The chances of someone looking at your issue are *vastly* improved if you provide complete code that can be copy/pasted and executed directly in Python. validations: required: true - type: textarea attributes: label: Expected behavior description: >- A clear and concise description of what you expected to happen. validations: required: true - type: textarea attributes: label: Logs/tracebacks description: | If applicable, add logs/tracebacks to help explain your problem. Paste the output of the steps above, including the commands themselves and their output/traceback etc. render: python-traceback validations: required: true - type: textarea attributes: label: Python Version description: Attach your version of Python. render: console value: | $ python --version validations: required: true - type: textarea attributes: label: aiohttp Version description: Attach your version of aiohttp. render: console value: | $ python -m pip show aiohttp validations: required: true - type: textarea attributes: label: multidict Version description: Attach your version of multidict. render: console value: | $ python -m pip show multidict validations: required: true - type: textarea attributes: label: propcache Version description: Attach your version of propcache. render: console value: | $ python -m pip show propcache validations: required: true - type: textarea attributes: label: yarl Version description: Attach your version of yarl. 
render: console value: | $ python -m pip show yarl validations: required: true - type: textarea attributes: label: OS placeholder: >- For example, Arch Linux, Windows, macOS, etc. validations: required: true - type: dropdown attributes: label: Related component description: >- aiohttp is both a server framework and a client library. To avoid confusion, make sure to select 'server', 'client' or both. multiple: true options: - Server - Client validations: required: true - type: textarea attributes: label: Additional context description: | Add any other context about the problem here. Describe the environment you have that led to your issue. This includes proxy server and other bits that are related to your case. - type: checkboxes attributes: label: Code of Conduct description: | Read the [aio-libs Code of Conduct][CoC] first. [CoC]: https://github.com/aio-libs/.github/blob/master/CODE_OF_CONDUCT.md options: - label: I agree to follow the aio-libs Code of Conduct required: true ... ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ # Ref: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository#configuring-the-template-chooser blank_issues_enabled: false # default: true contact_links: - name: 🤷💻🤦 StackOverflow url: https://stackoverflow.com/questions/tagged/aiohttp about: Please ask typical Q&A here - name: 💬 GitHub Discussions url: https://github.com/aio-libs/aiohttp/discussions about: Please start usage discussions here - name: 💬 Gitter Chat url: https://gitter.im/aio-libs/Lobby about: Chat with devs and community ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.yml ================================================ --- name: 🚀 Feature request description: Suggest an idea for this project.
labels: enhancement body: - type: markdown attributes: value: | **Thanks for taking a minute to file a feature request for aiohttp!** ⚠ Verify first that your feature request is not [already reported on GitHub][issue search]. _Please fill out the form below with as many precise details as possible._ [issue search]: ../search?q=is%3Aissue&type=issues - type: textarea attributes: label: Is your feature request related to a problem? description: >- Please add a clear and concise description of what the problem is. _Ex. I'm always frustrated when [...]_ - type: textarea attributes: label: Describe the solution you'd like description: >- A clear and concise description of what you want to happen. validations: required: true - type: textarea attributes: label: Describe alternatives you've considered description: >- A clear and concise description of any alternative solutions or features you've considered. validations: required: true - type: dropdown attributes: label: Related component description: >- aiohttp is both a server framework and a client library. To avoid confusion, make sure to select 'server', 'client' or both. multiple: true options: - Server - Client validations: required: true - type: textarea attributes: label: Additional context description: >- Add any other context or screenshots about the feature request here. - type: checkboxes attributes: label: Code of Conduct description: | Read the [aio-libs Code of Conduct][CoC] first. [CoC]: https://github.com/aio-libs/.github/blob/master/CODE_OF_CONDUCT.md options: - label: I agree to follow the aio-libs Code of Conduct required: true ...
================================================ FILE: .github/ISSUE_TEMPLATE.md ================================================ ## Long story short ## Expected behaviour ## Actual behaviour ## Steps to reproduce ## Your environment ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ ## What do these changes do? ## Are there changes in behavior for the user? ## Is it a substantial burden for the maintainers to support this? ## Related issue number ## Checklist - [ ] I think the code is well written - [ ] Unit tests for the changes exist - [ ] Documentation reflects the changes - [ ] If you provide code modification, please add yourself to `CONTRIBUTORS.txt` * The format is <Name> <Surname>. * Please keep alphabetical order, the file is sorted by names. - [ ] Add a new news fragment into the `CHANGES/` folder * name it `<issue_or_pr_num>.<type>.rst` (e.g. `588.bugfix.rst`) * if you don't have an issue number, change it to the pull request number after creating the PR * `.bugfix`: A bug fix for something the maintainers deemed an improper undesired behavior that got corrected to match pre-agreed expectations. * `.feature`: A new behavior, public APIs. That sort of stuff. * `.deprecation`: A declaration of future API removals and breaking changes in behavior. * `.breaking`: When something public is removed in a breaking way. Could be deprecated in an earlier release. * `.doc`: Notable updates to the documentation structure or build process. * `.packaging`: Notes for downstreams about unobvious side effects and tooling. Changes in the test invocation considerations and runtime assumptions. * `.contrib`: Stuff that affects the contributor experience. e.g. Running tests, building the docs, setting up the development environment. * `.misc`: Changes that are hard to assign to any of the above categories.
* Make sure to use full sentences with correct case and punctuation, for example: ```rst Fixed issue with non-ascii contents in doctest text files -- by :user:`contributor-gh-handle`. ``` Use the past tense or the present tense in a non-imperative mood, referring to what's changed compared to the last released version of this project. ================================================ FILE: .github/codeql.yml ================================================ query-filters: - exclude: id: - py/ineffectual-statement - py/unsafe-cyclic-import ================================================ FILE: .github/config.yml ================================================ chronographer: exclude: bots: - dependabot-preview - dependabot - patchback humans: - pyup-bot ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: # Maintain dependencies for GitHub Actions - package-ecosystem: "github-actions" directory: "/" labels: - dependencies schedule: interval: "daily" # Maintain dependencies for Python - package-ecosystem: "pip" directory: "/" allow: - dependency-type: "all" labels: - dependencies schedule: interval: "daily" open-pull-requests-limit: 10 # Maintain dependencies for GitHub Actions aiohttp backport - package-ecosystem: "github-actions" directory: "/" labels: - dependencies target-branch: "3.14" schedule: interval: "daily" open-pull-requests-limit: 10 # Maintain dependencies for Python aiohttp backport - package-ecosystem: "pip" directory: "/" allow: - dependency-type: "all" labels: - dependencies target-branch: "3.14" schedule: interval: "daily" open-pull-requests-limit: 10 - package-ecosystem: "docker" directory: "/tests/autobahn/" labels: - dependencies schedule: interval: "monthly" - package-ecosystem: "docker" directory: "/tests/autobahn/" labels: - dependencies target-branch: "3.14" schedule: interval: "monthly" ================================================ FILE:
.github/lock.yml ================================================ # Configuration for Lock Threads - https://github.com/dessant/lock-threads # GitHub App - https://github.com/apps/lock --- # Number of days of inactivity before a closed issue or pull request is locked daysUntilLock: 365 # Skip issues and pull requests created before a given timestamp. Timestamp must # follow ISO 8601 (`YYYY-MM-DD`). Set to `false` to disable skipCreatedBefore: false # Issues and pull requests with these labels will be ignored. # Set to `[]` to disable exemptLabels: [] # Label to add before locking, such as `outdated`. # Set to `false` to disable lockLabel: outdated # Comment to post before locking. Set to `false` to disable lockComment: false # Assign `resolved` as the reason for locking. Set to `false` to disable setLockReason: true # Limit to only `issues` or `pulls` # only: issues # Optionally, specify configuration settings just for `issues` or `pulls` # issues: # exemptLabels: # - help-wanted # lockLabel: outdated # pulls: # daysUntilLock: 30 # Repository to extend settings from # _extends: repo ================================================ FILE: .github/workflows/auto-merge.yml ================================================ name: Dependabot auto-merge on: pull_request_target permissions: pull-requests: write contents: write jobs: dependabot: runs-on: ubuntu-latest if: ${{ github.actor == 'dependabot[bot]' }} steps: - name: Dependabot metadata id: metadata uses: dependabot/fetch-metadata@v2 with: github-token: "${{ secrets.GITHUB_TOKEN }}" - name: Enable auto-merge for Dependabot PRs run: gh pr merge --auto --squash "$PR_URL" env: PR_URL: ${{github.event.pull_request.html_url}} GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} ================================================ FILE: .github/workflows/ci-cd.yml ================================================ name: CI on: merge_group: push: branches: - 'master' - '[0-9].[0-9]+' # matches to backport branches, e.g. 
3.6 tags: [ 'v*' ] pull_request: branches: - 'master' - '[0-9].[0-9]+' schedule: - cron: '0 6 * * *' # Daily 6AM UTC build concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} cancel-in-progress: true env: COLOR: yes FORCE_COLOR: 1 # Request colored output from CLI tools supporting it MYPY_FORCE_COLOR: 1 PY_COLORS: 1 UPSTREAM_REPOSITORY_ID: >- 13258039 permissions: {} jobs: pre-setup: name: Pre-Setup global build settings runs-on: ubuntu-latest outputs: upstream-repository-id: ${{ env.UPSTREAM_REPOSITORY_ID }} release-requested: >- ${{ ( github.event_name == 'push' && github.ref_type == 'tag' ) && true || false }} steps: - name: Dummy if: false run: | echo "Pre-setup step" lint: permissions: contents: read # to fetch code (actions/checkout) name: Linter runs-on: ubuntu-latest timeout-minutes: 5 steps: - name: Checkout uses: actions/checkout@v6 with: submodules: true - name: >- Verify that `requirements/runtime-deps.in` is in sync with `pyproject.toml` run: | set -eEuo pipefail make sync-direct-runtime-deps git diff --exit-code -- requirements/runtime-deps.in - name: Setup Python uses: actions/setup-python@v6 with: python-version: 3.11 - name: Cache PyPI uses: actions/cache@v5.0.4 with: key: pip-lint-${{ hashFiles('requirements/*.txt') }} path: ~/.cache/pip restore-keys: | pip-lint- - name: Update pip, wheel, setuptools, build, twine run: | python -m pip install -U pip wheel setuptools build twine - name: Install dependencies run: | python -m pip install -r requirements/lint.in -c requirements/lint.txt - name: Install self run: | python -m pip install . -c requirements/runtime-deps.txt env: AIOHTTP_NO_EXTENSIONS: 1 - name: Run mypy run: | make mypy - name: Run slotscheck run: | # Some extra requirements are needed to ensure all modules # can be scanned by slotscheck. 
pip install -r requirements/base.in -c requirements/base.txt slotscheck -v -m aiohttp - name: Install spell checker run: | pip install -r requirements/doc-spelling.in -c requirements/doc-spelling.txt - name: Run docs spelling run: | # towncrier --yes # uncomment me after publishing a release make doc-spelling - name: Build package run: | python -m build env: AIOHTTP_NO_EXTENSIONS: 1 - name: Run twine checker run: | twine check --strict dist/* - name: Making sure that CONTRIBUTORS.txt remains sorted run: | LC_ALL=C sort --check --ignore-case CONTRIBUTORS.txt gen_llhttp: permissions: contents: read # to fetch code (actions/checkout) name: Generate llhttp sources runs-on: ubuntu-latest timeout-minutes: 5 steps: - name: Checkout uses: actions/checkout@v6 with: submodules: true - name: Cache llhttp generated files uses: actions/cache@v5.0.4 id: cache with: key: llhttp-${{ hashFiles('vendor/llhttp/package*.json', 'vendor/llhttp/src/**/*') }} path: vendor/llhttp/build - name: Setup NodeJS if: steps.cache.outputs.cache-hit != 'true' uses: actions/setup-node@v6 with: node-version: 18 - name: Generate llhttp sources if: steps.cache.outputs.cache-hit != 'true' run: | make generate-llhttp - name: Upload llhttp generated files uses: actions/upload-artifact@v6 with: name: llhttp path: vendor/llhttp/build if-no-files-found: error test: permissions: contents: read # to fetch code (actions/checkout) name: Test needs: gen_llhttp strategy: matrix: pyver: ['3.10', '3.11', '3.12', '3.13', '3.14'] no-extensions: ['', 'Y'] os: [ubuntu, macos, windows] experimental: [false] exclude: - os: macos no-extensions: 'Y' - os: windows no-extensions: 'Y' include: - pyver: pypy-3.11 no-extensions: 'Y' os: ubuntu experimental: false - os: ubuntu pyver: "3.14t" no-extensions: '' experimental: false fail-fast: true runs-on: ${{ matrix.os }}-latest continue-on-error: ${{ matrix.experimental }} steps: - name: Checkout uses: actions/checkout@v6 with: submodules: true - name: Setup Python ${{ matrix.pyver 
}} id: python-install uses: actions/setup-python@v6 with: allow-prereleases: true python-version: ${{ matrix.pyver }} - name: Get pip cache dir id: pip-cache run: | echo "dir=$(pip cache dir)" >> "${GITHUB_OUTPUT}" shell: bash - name: Cache PyPI uses: actions/cache@v5.0.4 with: key: pip-ci-${{ runner.os }}-${{ matrix.pyver }}-${{ matrix.no-extensions }}-${{ hashFiles('requirements/*.txt') }} path: ${{ steps.pip-cache.outputs.dir }} restore-keys: | pip-ci-${{ runner.os }}-${{ matrix.pyver }}-${{ matrix.no-extensions }}- - name: Update pip, wheel, setuptools, build, twine run: | python -m pip install -U pip wheel setuptools build twine - name: Install dependencies env: DEPENDENCY_GROUP: test${{ endsWith(matrix.pyver, 't') && '-ft' || '' }} run: | python -Im pip install -r requirements/${{ env.DEPENDENCY_GROUP }}.in -c requirements/${{ env.DEPENDENCY_GROUP }}.txt - name: Set PYTHON_GIL=0 for free-threading builds if: ${{ endsWith(matrix.pyver, 't') }} run: echo "PYTHON_GIL=0" >> $GITHUB_ENV - name: Restore llhttp generated files if: ${{ matrix.no-extensions == '' }} uses: actions/download-artifact@v8 with: name: llhttp path: vendor/llhttp/build/ - name: Cythonize if: ${{ matrix.no-extensions == '' }} run: | make cythonize - name: Install self env: AIOHTTP_NO_EXTENSIONS: ${{ matrix.no-extensions }} run: python -m pip install -e . 
- name: Run unittests env: COLOR: yes AIOHTTP_NO_EXTENSIONS: ${{ matrix.no-extensions }} PIP_USER: 1 run: >- PATH="${HOME}/Library/Python/3.11/bin:${HOME}/.local/bin:${PATH}" pytest --junitxml=junit.xml -m 'not dev_mode and not autobahn' shell: bash - name: Re-run the failing tests with maximum verbosity if: failure() env: COLOR: yes AIOHTTP_NO_EXTENSIONS: ${{ matrix.no-extensions }} run: >- # `exit 1` makes sure that the job remains red with flaky runs pytest --no-cov --numprocesses=0 -vvvvv --lf && exit 1 shell: bash - name: Run dev_mode tests env: COLOR: yes AIOHTTP_NO_EXTENSIONS: ${{ matrix.no-extensions }} PIP_USER: 1 PYTHONDEVMODE: 1 run: pytest -m dev_mode --cov-append --numprocesses=0 shell: bash - name: Turn coverage into xml env: COLOR: 'yes' PIP_USER: 1 run: | python -m coverage xml - name: Upload coverage uses: codecov/codecov-action@v5 with: files: ./coverage.xml flags: >- CI-GHA,OS-${{ runner.os }},VM-${{ matrix.os }},Py-${{ steps.python-install.outputs.python-version }} token: ${{ secrets.CODECOV_TOKEN }} - name: Upload test results to Codecov if: ${{ !cancelled() }} uses: codecov/test-results-action@v1 with: token: ${{ secrets.CODECOV_TOKEN }} autobahn: permissions: contents: read # to fetch code (actions/checkout) name: Autobahn testsuite needs: gen_llhttp strategy: matrix: pyver: ['3.14'] no-extensions: [''] os: [ubuntu] fail-fast: true runs-on: ${{ matrix.os }}-latest steps: - name: Checkout uses: actions/checkout@v6 with: submodules: true - name: Setup Python ${{ matrix.pyver }} id: python-install uses: actions/setup-python@v6 with: allow-prereleases: true python-version: ${{ matrix.pyver }} - name: Get pip cache dir id: pip-cache run: | echo "dir=$(pip cache dir)" >> "${GITHUB_OUTPUT}" shell: bash - name: Cache PyPI uses: actions/cache@v5.0.4 with: key: pip-ci-${{ runner.os }}-${{ matrix.pyver }}-${{ matrix.no-extensions }}-${{ hashFiles('requirements/*.txt') }} path: ${{ steps.pip-cache.outputs.dir }} restore-keys: | pip-ci-${{ runner.os 
}}-${{ matrix.pyver }}-${{ matrix.no-extensions }}- - name: Update pip, wheel, setuptools, build, twine run: | python -m pip install -U pip wheel setuptools build twine - name: Install dependencies env: DEPENDENCY_GROUP: test${{ endsWith(matrix.pyver, 't') && '-ft' || '' }} run: | python -Im pip install -r requirements/${{ env.DEPENDENCY_GROUP }}.in -c requirements/${{ env.DEPENDENCY_GROUP }}.txt - name: Restore llhttp generated files if: ${{ matrix.no-extensions == '' }} uses: actions/download-artifact@v8 with: name: llhttp path: vendor/llhttp/build/ - name: Cythonize if: ${{ matrix.no-extensions == '' }} run: | make cythonize - name: Install self env: AIOHTTP_NO_EXTENSIONS: ${{ matrix.no-extensions }} run: python -m pip install -e . - name: Run unittests env: COLOR: yes AIOHTTP_NO_EXTENSIONS: ${{ matrix.no-extensions }} PIP_USER: 1 run: >- PATH="${HOME}/Library/Python/3.11/bin:${HOME}/.local/bin:${PATH}" pytest --junitxml=junit.xml --numprocesses=0 -m autobahn shell: bash - name: Turn coverage into xml env: COLOR: 'yes' PIP_USER: 1 run: | python -m coverage xml - name: Upload coverage uses: codecov/codecov-action@v5 with: files: ./coverage.xml flags: >- CI-GHA,OS-${{ runner.os }},VM-${{ matrix.os }},Py-${{ steps.python-install.outputs.python-version }} token: ${{ secrets.CODECOV_TOKEN }} - name: Upload test results to Codecov if: ${{ !cancelled() }} uses: codecov/test-results-action@v1 with: token: ${{ secrets.CODECOV_TOKEN }} benchmark: name: Benchmark needs: - gen_llhttp - pre-setup # transitive, for accessing settings if: >- needs.pre-setup.outputs.upstream-repository-id == github.repository_id runs-on: ubuntu-latest timeout-minutes: 12 steps: - name: Checkout project uses: actions/checkout@v6 with: submodules: true - name: Setup Python 3.13.2 id: python-install uses: actions/setup-python@v6 with: python-version: 3.13.2 cache: pip cache-dependency-path: requirements/*.txt - name: Update pip, wheel, setuptools, build, twine run: | python -m pip install -U pip 
wheel setuptools build twine - name: Install dependencies run: | python -m pip install -r requirements/test.in -c requirements/test.txt - name: Restore llhttp generated files uses: actions/download-artifact@v8 with: name: llhttp path: vendor/llhttp/build/ - name: Cythonize run: | make cythonize - name: Install self run: python -m pip install -e . - name: Run benchmarks uses: CodSpeedHQ/action@v4 with: mode: instrumentation run: python -Im pytest --no-cov --numprocesses=0 -vvvvv --codspeed check: # This job does nothing and is only used for the branch protection if: always() needs: - lint - test - autobahn runs-on: ubuntu-latest steps: - name: Decide whether the needed jobs succeeded or failed uses: re-actors/alls-green@release/v1 with: jobs: ${{ toJSON(needs) }} - name: Trigger codecov notification uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: true run_command: send-notifications pre-deploy: name: Pre-Deploy runs-on: ubuntu-latest needs: - check - pre-setup # transitive, for accessing settings if: fromJSON(needs.pre-setup.outputs.release-requested) steps: - name: Dummy run: | echo "Predeploy step" build-tarball: permissions: contents: read # to fetch code (actions/checkout) name: Tarball runs-on: ubuntu-latest needs: pre-deploy steps: - name: Checkout uses: actions/checkout@v6 with: submodules: true - name: Setup Python uses: actions/setup-python@v6 - name: Update pip, wheel, setuptools, build, twine run: | python -m pip install -U pip wheel setuptools build twine - name: Install cython run: >- python -m pip install -r requirements/cython.in -c requirements/cython.txt - name: Restore llhttp generated files uses: actions/download-artifact@v8 with: name: llhttp path: vendor/llhttp/build/ - name: Cythonize run: | make cythonize - name: Make sdist run: | python -m build --sdist - name: Upload artifacts uses: actions/upload-artifact@v6 with: name: dist-sdist path: dist build-wheels: permissions: contents: read # to fetch 
code (actions/checkout) name: Build wheels on ${{ matrix.os }} ${{ matrix.qemu }} ${{ matrix.musl }} runs-on: ${{ matrix.os }} needs: pre-deploy strategy: matrix: os: ["ubuntu-latest", "windows-latest", "windows-11-arm", "macos-latest", "ubuntu-24.04-arm"] qemu: [''] musl: [""] include: # Split ubuntu/musl jobs for the sake of speed-up - os: ubuntu-latest qemu: ppc64le musl: "" - os: ubuntu-latest qemu: ppc64le musl: musllinux - os: ubuntu-latest qemu: riscv64 musl: "" - os: ubuntu-latest qemu: riscv64 musl: musllinux - os: ubuntu-latest qemu: s390x musl: "" - os: ubuntu-latest qemu: s390x musl: musllinux - os: ubuntu-latest qemu: armv7l musl: "" - os: ubuntu-latest qemu: armv7l musl: musllinux - os: ubuntu-latest musl: musllinux - os: ubuntu-24.04-arm musl: musllinux steps: - name: Checkout uses: actions/checkout@v6 with: submodules: true - name: Set up QEMU if: ${{ matrix.qemu }} uses: docker/setup-qemu-action@v4 with: platforms: all # This should be temporary # xref https://github.com/docker/setup-qemu-action/issues/188 # xref https://github.com/tonistiigi/binfmt/issues/215 image: tonistiigi/binfmt:qemu-v8.1.5 id: qemu - name: Prepare emulation run: | if [[ -n "${{ matrix.qemu }}" ]]; then # Build emulated architectures only if QEMU is set, # use default "auto" otherwise echo "CIBW_ARCHS_LINUX=${{ matrix.qemu }}" >> $GITHUB_ENV fi shell: bash - name: Setup Python uses: actions/setup-python@v6 with: python-version: 3.x - name: Update pip, wheel, setuptools, build, twine run: | python -m pip install -U pip wheel setuptools build twine - name: Install cython run: >- python -m pip install -r requirements/cython.in -c requirements/cython.txt - name: Restore llhttp generated files uses: actions/download-artifact@v8 with: name: llhttp path: vendor/llhttp/build/ - name: Cythonize run: | make cythonize - name: Build wheels uses: pypa/cibuildwheel@v3.4.0 env: CIBW_SKIP: pp* ${{ matrix.musl == 'musllinux' && '*manylinux*' || '*musllinux*' }} CIBW_ARCHS_MACOS: x86_64 arm64 
universal2 - name: Upload wheels uses: actions/upload-artifact@v6 with: name: >- dist-${{ matrix.os }}-${{ matrix.musl }}-${{ matrix.qemu && matrix.qemu || 'native' }} path: ./wheelhouse/*.whl deploy: name: Deploy needs: - build-tarball - build-wheels - pre-setup # transitive, for accessing settings runs-on: ubuntu-latest if: >- needs.pre-setup.outputs.upstream-repository-id == github.repository_id permissions: contents: write # IMPORTANT: mandatory for making GitHub Releases id-token: write # IMPORTANT: mandatory for trusted publishing & sigstore environment: name: pypi url: https://pypi.org/p/aiohttp steps: - name: Checkout uses: actions/checkout@v6 with: submodules: true - name: Login run: | echo "${{ secrets.GITHUB_TOKEN }}" | gh auth login --with-token - name: Download distributions uses: actions/download-artifact@v8 with: path: dist pattern: dist-* merge-multiple: true - name: Collected dists run: | tree dist - name: Make Release uses: aio-libs/create-release@v1.6.6 with: changes_file: CHANGES.rst name: aiohttp version_file: aiohttp/__init__.py github_token: ${{ secrets.GITHUB_TOKEN }} dist_dir: dist fix_issue_regex: >- :issue:`(\d+)` fix_issue_repl: >- #\1 - name: >- Publish 🐍📦 to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - name: Sign the dists with Sigstore uses: sigstore/gh-action-sigstore-python@v3.2.0 with: inputs: >- ./dist/*.tar.gz ./dist/*.whl - name: Upload artifact signatures to GitHub Release # Confusingly, this action also supports updating releases, not # just creating them. This is what we want here, since we've manually # created the release above. uses: softprops/action-gh-release@v2 with: # dist/ contains the built packages, which smoketest-artifacts/ # contains the signatures and certificates. 
files: dist/** ================================================ FILE: .github/workflows/codeql.yml ================================================ name: "CodeQL" on: push: branches: - 'master' - '[0-9].[0-9]+' # matches to backport branches, e.g. 3.6 pull_request: branches: [ "master" ] schedule: - cron: "9 1 * * 4" jobs: analyze: name: Analyze runs-on: ubuntu-latest permissions: actions: read contents: read security-events: write strategy: fail-fast: false matrix: language: [ python, javascript ] steps: - name: Checkout uses: actions/checkout@v6 - name: Initialize CodeQL uses: github/codeql-action/init@v4 with: languages: ${{ matrix.language }} config-file: ./.github/codeql.yml queries: +security-and-quality - name: Autobuild uses: github/codeql-action/autobuild@v4 if: ${{ matrix.language == 'python' || matrix.language == 'javascript' }} - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v4 with: category: "/language:${{ matrix.language }}" ================================================ FILE: .github/workflows/label-remove.yml ================================================ name: Clear needs-info/pr-unfinished on activity on: pull_request: types: [synchronize, review_requested] issue_comment: types: [created] pull_request_review_comment: types: [created] jobs: clear-pr-unfinished: runs-on: ubuntu-latest permissions: pull-requests: write steps: - name: Remove label uses: actions-ecosystem/action-remove-labels@v1 with: labels: | needs-info pr-unfinished ================================================ FILE: .github/workflows/labels.yml ================================================ name: Labels on: pull_request: branches: - 'master' types: [labeled, opened, synchronize, reopened, unlabeled] jobs: backport: runs-on: ubuntu-latest name: Backport label added if: ${{ github.event.pull_request.user.type != 'Bot' }} steps: - uses: actions/github-script@v8 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | const pr = await 
github.rest.pulls.get({ owner: context.repo.owner, repo: context.repo.repo, pull_number: context.payload.pull_request.number }); if (!pr.data.labels.find(l => l.name.startsWith("backport"))) process.exit(1); ================================================ FILE: .github/workflows/stale.yml ================================================ name: 'Close stale issues' on: schedule: - cron: '50 5 * * *' permissions: issues: write jobs: stale: runs-on: ubuntu-latest steps: - uses: actions/stale@v10 with: days-before-stale: 30 any-of-labels: needs-info ================================================ FILE: .gitignore ================================================ *.bak *.egg *.egg-info *.eggs *.md5 *.pyc *.pyd *.pyo *.so *.swp *.tar.gz *~ .DS_Store .Python .cache .codspeed .coverage .coverage.* .develop .direnv .envrc .flake .gitconfig .hash .idea .install-cython .install-deps .llhttp-gen .installed.cfg .mypy_cache .noseids .pytest_cache .python-version .test-results .tox .vimrc .vscode aiohttp/_find_header.c aiohttp/_headers.html aiohttp/_headers.pxi aiohttp/_http_parser.c aiohttp/_http_parser.html aiohttp/_http_writer.c aiohttp/_http_writer.html aiohttp/_websocket.c aiohttp/_websocket.html aiohttp/_websocket/mask.c aiohttp/_websocket/reader_c.c bin build coverage.xml develop-eggs dist docs/_build/ eggs htmlcov include/ lib/ man/ nosetests.xml parts pip-wheel-metadata pyvenv sources var/* venv virtualenv.py ================================================ FILE: .gitmodules ================================================ [submodule "vendor/llhttp"] path = vendor/llhttp url = https://github.com/nodejs/llhttp.git branch = main ================================================ FILE: .lgtm.yml ================================================ queries: - exclude: py/unsafe-cyclic-import ================================================ FILE: .mypy.ini ================================================ [mypy] files = aiohttp, docs/code, examples, tests check_untyped_defs = True 
follow_imports_for_stubs = True disallow_any_decorated = True disallow_any_generics = True disallow_any_unimported = True disallow_incomplete_defs = True disallow_subclassing_any = True disallow_untyped_calls = True disallow_untyped_decorators = True disallow_untyped_defs = True # TODO(PY312): explicit-override enable_error_code = deprecated, exhaustive-match, ignore-without-code, possibly-undefined, redundant-expr, redundant-self, truthy-bool, truthy-iterable, unused-awaitable extra_checks = True follow_untyped_imports = True implicit_reexport = False no_implicit_optional = True pretty = True show_column_numbers = True show_error_codes = True show_error_code_links = True strict_bytes = True strict_equality = True warn_incomplete_stub = True warn_redundant_casts = True warn_return_any = True warn_unreachable = True warn_unused_ignores = True [mypy-brotli] ignore_missing_imports = True [mypy-brotlicffi] ignore_missing_imports = True [mypy-gunicorn.*] ignore_missing_imports = True ================================================ FILE: .pip-tools.toml ================================================ [pip-tools] allow-unsafe = true resolver = "backtracking" strip-extras = true ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: local hooks: - id: changelogs-rst name: changelog filenames language: fail entry: >- Changelog files must be named ####.( bugfix | feature | deprecation | breaking | doc | packaging | contrib | misc )(.#)?(.rst)? exclude: >- (?x) ^ CHANGES/( \.gitignore |(\d+|[0-9a-f]{8}|[0-9a-f]{7}|[0-9a-f]{40})\.( bugfix |feature |deprecation |breaking |doc |packaging |contrib |misc )(\.\d+)?(\.rst)? 
|README\.rst |\.TEMPLATE\.rst ) $ files: ^CHANGES/ - id: changelogs-user-role name: Changelog files should use a non-broken :user:`name` role language: pygrep entry: :user:([^`]+`?|`[^`]+[\s,]) pass_filenames: true types: [file, rst] - id: check-changes name: Check CHANGES language: system entry: ./tools/check_changes.py pass_filenames: false - repo: https://github.com/pre-commit/pre-commit-hooks rev: 'v6.0.0' hooks: - id: check-merge-conflict - repo: https://github.com/asottile/yesqa rev: v1.5.0 hooks: - id: yesqa additional_dependencies: - flake8-docstrings==1.6.0 - flake8-no-implicit-concat==0.3.4 - flake8-requirements==1.7.8 - repo: https://github.com/PyCQA/isort rev: '8.0.1' hooks: - id: isort - repo: https://github.com/psf/black-pre-commit-mirror rev: '26.3.1' hooks: - id: black language_version: python3 # Should be a command that runs python - repo: https://github.com/pre-commit/pre-commit-hooks rev: 'v6.0.0' hooks: - id: end-of-file-fixer exclude: >- ^docs/[^/]*\.svg$ - id: requirements-txt-fixer files: requirements/.*\.in$ - id: trailing-whitespace - id: file-contents-sorter args: ['--ignore-case'] files: | CONTRIBUTORS.txt| docs/spelling_wordlist.txt| .gitignore| .gitattributes - id: check-case-conflict - id: check-json - id: check-xml - id: check-executables-have-shebangs - id: check-toml - id: check-yaml - id: debug-statements - id: check-added-large-files - id: check-symlinks - id: fix-byte-order-marker - id: detect-aws-credentials args: ['--allow-missing-credentials'] - id: detect-private-key exclude: ^examples/ - repo: https://github.com/asottile/pyupgrade rev: 'v3.21.2' hooks: - id: pyupgrade args: ['--py37-plus'] - repo: https://github.com/PyCQA/flake8 rev: '7.3.0' hooks: - id: flake8 additional_dependencies: - flake8-docstrings==1.6.0 - flake8-no-implicit-concat==0.3.4 - flake8-requirements==1.7.8 exclude: "^docs/" - repo: https://github.com/Lucas-C/pre-commit-hooks-markup rev: v1.0.1 hooks: - id: rst-linter files: >- ^[^/]+[.]rst$ exclude: >- 
^CHANGES\.rst$ - repo: https://github.com/codespell-project/codespell rev: v2.4.2 hooks: - id: codespell additional_dependencies: - tomli ================================================ FILE: .readthedocs.yml ================================================ # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html # for details --- version: 2 sphinx: # Path to your Sphinx configuration file. configuration: docs/conf.py submodules: include: all exclude: [] recursive: true build: os: ubuntu-24.04 tools: python: "3.11" apt_packages: - graphviz jobs: post_create_environment: - >- pip install . -c requirements/runtime-deps.txt -r requirements/doc.in -c requirements/doc.txt ... ================================================ FILE: CHANGES/.TEMPLATE.rst ================================================ {# TOWNCRIER TEMPLATE #} {% for section, _ in sections.items() %} {% set underline = underlines[0] %}{% if section %}{{section}} {{ underline * section|length }}{% set underline = underlines[1] %} {% endif %} {% if sections[section] %} {% for category, val in definitions.items() if category in sections[section]%} {{ definitions[category]['name'] }} {{ underline * definitions[category]['name']|length }} {% if definitions[category]['showcontent'] %} {% for text, change_note_refs in sections[section][category].items() %} - {{ text + '\n' }} {# NOTE: Replacing 'e' with 'f' is a hack that prevents Jinja's `int` NOTE: filter internal implementation from treating the input as an NOTE: infinite float when it looks like a scientific notation (with a NOTE: single 'e' char in between digits), raising an `OverflowError`, NOTE: subsequently. 'f' is still a hex letter so it won't affect the NOTE: check for whether it's a (short or long) commit hash or not. 
Ref: https://github.com/pallets/jinja/issues/1921 -#} {%- set pr_issue_numbers = change_note_refs | map('lower') | map('replace', 'e', 'f') | map('int', default=None) | select('integer') | map('string') | list -%} {%- set arbitrary_refs = [] -%} {%- set commit_refs = [] -%} {%- with -%} {%- set commit_ref_candidates = change_note_refs | reject('in', pr_issue_numbers) -%} {%- for cf in commit_ref_candidates -%} {%- if cf | length in (7, 8, 40) and cf | int(default=None, base=16) is not none -%} {%- set _ = commit_refs.append(cf) -%} {%- else -%} {%- set _ = arbitrary_refs.append(cf) -%} {%- endif -%} {%- endfor -%} {%- endwith -%} {% if pr_issue_numbers -%} *Related issues and pull requests on GitHub:* :issue:`{{ pr_issue_numbers | join('`, :issue:`') }}`. {% endif %} {% if commit_refs -%} *Related commits on GitHub:* :commit:`{{ commit_refs | join('`, :commit:`') }}`. {% endif %} {% if arbitrary_refs -%} *Unlinked references:* {{ arbitrary_refs | join(', ') }}. {% endif %} {% endfor %} {% else %} - {{ sections[section][category]['']|join(', ') }} {% endif %} {% if sections[section][category]|length == 0 %} No significant changes. {% else %} {% endif %} {% endfor %} {% else %} No significant changes.
{% endif %} {% endfor %} ---- {{ '\n' * 2 }} ================================================ FILE: CHANGES/.gitignore ================================================ * !.TEMPLATE.rst !.gitignore !README.rst !*.bugfix !*.bugfix.rst !*.bugfix.*.rst !*.breaking !*.breaking.rst !*.breaking.*.rst !*.contrib !*.contrib.rst !*.contrib.*.rst !*.deprecation !*.deprecation.rst !*.deprecation.*.rst !*.doc !*.doc.rst !*.doc.*.rst !*.feature !*.feature.rst !*.feature.*.rst !*.misc !*.misc.rst !*.misc.*.rst !*.packaging !*.packaging.rst !*.packaging.*.rst ================================================ FILE: CHANGES/10468.doc.rst ================================================ Added ``:canonical:`` directives to documentation reference pages, enabling ``Intersphinx`` cross-referencing via fully-qualified module paths (e.g. ``aiohttp.client.ClientSession``) -- by :user:`danielalanbates`. ================================================ FILE: CHANGES/10596.bugfix.rst ================================================ Fixed server hanging indefinitely when chunked transfer encoding chunk-size does not match actual data length. The server now raises ``TransferEncodingError`` instead of waiting forever for data that will never arrive -- by :user:`Fridayai700`. ================================================ FILE: CHANGES/10611.bugfix.rst ================================================ Reject HTTP requests with duplicate ``chunked`` Transfer-Encoding (e.g. ``Transfer-Encoding: chunked, chunked``) with a ``BadHttpMessage`` error, per :rfc:`9112` section 7.1 -- by :user:`worksbyfriday`. ================================================ FILE: CHANGES/10665.feature.rst ================================================ Added :py:attr:`~aiohttp.web.TCPSite.port` accessor for dynamic port allocations in :class:`~aiohttp.web.TCPSite` -- by :user:`twhittock-disguise` and :user:`rodrigobnogueira`. 
================================================ FILE: CHANGES/10683.bugfix.rst ================================================ Fixed misleading TLS-in-TLS warning being emitted when sending HTTPS requests through an HTTP proxy. The warning now only fires when the proxy itself uses HTTPS, which is the only case where TLS-in-TLS actually applies -- by :user:`wavebyrd`. ================================================ FILE: CHANGES/10753.bugfix.rst ================================================ Widened ``trace_request_ctx`` parameter type from ``Mapping[str, Any] | None`` to ``object`` to allow passing instances of user-defined classes as trace context -- by :user:`nightcityblade`. ================================================ FILE: CHANGES/10795.doc.rst ================================================ Replaced the deprecated ``ujson`` library with ``orjson`` in the client quickstart documentation. ``ujson`` has been put into maintenance-only mode; ``orjson`` is the recommended alternative. -- by :user:`indoor47` ================================================ FILE: CHANGES/11012.breaking.rst ================================================ Refactored ``ClientRequest`` class. This simplifies a lot of code and improves our type checking accuracy. It also better aligns public/private attributes with what we expect developers to access safely from a client middleware. If code subclasses ``ClientRequest``, it is likely that the subclass will need tweaking to be compatible with the new version. Similarly, subclasses of ``ClientResponse`` may need to adjust ``__init__`` parameters. -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/11268.feature.rst ================================================ Updated ``_TracingSignal`` to utilize a secondary generic variable for type hinting custom context variables -- by :user:`Vizonex`. 
================================================ FILE: CHANGES/11283.bugfix.rst ================================================ Fixed access log timestamps ignoring daylight saving time (DST) changes. The previous implementation used :py:data:`time.timezone` which is a constant and does not reflect DST transitions -- by :user:`nightcityblade`. ================================================ FILE: CHANGES/11601.breaking.rst ================================================ Dropped support for Python 3.9 -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/11681.feature.rst ================================================ Started accepting :term:`asynchronous context managers <asynchronous context manager>` for cleanup contexts. Legacy single-yield :term:`asynchronous generator` cleanup contexts continue to be supported; async context managers are adapted internally so they are entered at startup and exited during cleanup. -- by :user:`MannXo`. ================================================ FILE: CHANGES/11737.contrib.rst ================================================ The benchmark CI job now runs only in the upstream repository -- by :user:`Cycloctane`. It used to always fail in forks, which this change fixed. ================================================ FILE: CHANGES/11763.feature.rst ================================================ Added ``decode_text`` parameter to :meth:`~aiohttp.ClientSession.ws_connect` and :class:`~aiohttp.web.WebSocketResponse` to receive WebSocket TEXT messages as raw bytes instead of decoded strings, enabling direct use with high-performance JSON parsers like ``orjson`` -- by :user:`bdraco`. ================================================ FILE: CHANGES/11766.feature.rst ================================================ Added ``RequestKey`` and ``ResponseKey`` classes, which enable static type checking for request & response context storages in the same way that ``AppKey`` does for ``Application`` -- by :user:`gsoldatov`.
================================================ FILE: CHANGES/11776.misc.rst ================================================ The warnings emitted when using ``str`` keys in ``web.Response``/``web.Request`` have been removed to avoid any performance concerns when frequently using these -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/11826.contrib.rst ================================================ The coverage tool is now configured using the new native auto-discovered :file:`.coveragerc.toml` file -- by :user:`webknjaz`. It is also set up to use the ``ctrace`` core that works around the performance issues in the ``sysmon`` tracer which is default under Python 3.14. ================================================ FILE: CHANGES/11859.bugfix.rst ================================================ Removed support for ``ClientTimeout(total=0)`` to disable timeouts. Use ``None`` instead of ``0`` to disable the total timeout. Passing ``0`` now raises :exc:`ValueError` with a clear error message -- by :user:`veeceey`. ================================================ FILE: CHANGES/11876.misc.rst ================================================ Refactored tests to use ``create_autospec()`` for more robust mocking -- by :user:`soheil-star01`. ================================================ FILE: CHANGES/11898.bugfix.rst ================================================ Restored :py:meth:`~aiohttp.BodyPartReader.decode` as a synchronous method for backward compatibility. The method was inadvertently changed to async in 3.13.3 as part of the decompression bomb security fix. A new :py:meth:`~aiohttp.BodyPartReader.decode_iter` method is now available for non-blocking decompression of large payloads using an async generator. Internal aiohttp code uses the async variant to maintain security protections. 
Changed multipart processing chunk sizes from 64 KiB to 256 KiB, to better match aiohttp internals -- by :user:`bdraco` and :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/11937.misc.rst ================================================ Added ``win_arm64`` to the wheels that get pushed to PyPI -- by :user:`AraHaan`. ================================================ FILE: CHANGES/11955.feature.rst ================================================ Added ``max_headers`` parameter to limit the number of headers that should be read from a response -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/11972.bugfix.rst ================================================ Fixed false-positive :py:class:`DeprecationWarning` for passing ``enable_cleanup_closed=True`` to :py:class:`~aiohttp.TCPConnector` specifically on Python 3.12.7. -- by :user:`Robsdedude`. ================================================ FILE: CHANGES/11989.feature.rst ================================================ Added explicit APIs for bytes-returning JSON serializer: ``JSONBytesEncoder`` type, ``JsonBytesPayload``, :func:`~aiohttp.web.json_bytes_response`, :meth:`~aiohttp.web.WebSocketResponse.send_json_bytes` and :meth:`~aiohttp.ClientWebSocketResponse.send_json_bytes` methods, and ``json_serialize_bytes`` parameter for :class:`~aiohttp.ClientSession` -- by :user:`kevinpark1217`. ================================================ FILE: CHANGES/11992.contrib.rst ================================================ Fixed flaky performance tests by using appropriate fixed thresholds that account for CI variability -- by :user:`rodrigobnogueira`.
================================================ FILE: CHANGES/12027.misc.rst ================================================ Fixed ``test_invalid_idna`` to work with ``idna`` 3.11 by using an invalid character (``\u0080``) that is rejected by ``yarl`` during URL construction -- by :user:`rodrigobnogueira`. ================================================ FILE: CHANGES/12030.bugfix.rst ================================================ Reset the WebSocket heartbeat timer on inbound data to avoid false ping/pong timeouts while receiving large frames -- by :user:`hoffmang9`. ================================================ FILE: CHANGES/12042.doc.rst ================================================ Documented :exc:`asyncio.TimeoutError` for ``WebSocketResponse.receive()`` and related methods -- by :user:`veeceey`. ================================================ FILE: CHANGES/12069.packaging.rst ================================================ Upgraded llhttp to 3.9.1 -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/12088.bugfix.rst ================================================ Fixed tests to pass when run after 2027-05-31 -- by :user:`bmwiedemann`. ================================================ FILE: CHANGES/12091.bugfix.rst ================================================ Switched :py:meth:`~aiohttp.CookieJar.save` to use JSON format and :py:meth:`~aiohttp.CookieJar.load` to try JSON first with a fallback to a restricted pickle unpickler that only allows cookie-related types (``SimpleCookie``, ``Morsel``, ``defaultdict``, etc.), preventing arbitrary code execution via malicious pickle payloads (CWE-502) -- by :user:`YuvalElbar6`. ================================================ FILE: CHANGES/12096.bugfix.rst ================================================ Fixed _sendfile_fallback over-reading beyond requested count -- by :user:`bysiber`. 
================================================ FILE: CHANGES/12097.bugfix.rst ================================================ Fixed digest auth dropping challenge fields with empty string values -- by :user:`bysiber`. ================================================ FILE: CHANGES/12106.feature.rst ================================================ Added a ``dns_cache_max_size`` parameter to ``TCPConnector`` to limit the size of the cache -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/12136.bugfix.rst ================================================ ``ClientConnectorCertificateError.os_error`` no longer raises :exc:`AttributeError` -- by :user:`themylogin`. ================================================ FILE: CHANGES/12170.misc.rst ================================================ Fixed race condition in ``test_data_file`` on Python 3.14 free-threaded builds -- by :user:`rodrigobnogueira`. ================================================ FILE: CHANGES/12173.contrib.rst ================================================ Fixed and reworked ``autobahn`` tests -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/12195.bugfix.rst ================================================ Fixed redirects with consumed non-rewindable request bodies to raise :class:`aiohttp.ClientPayloadError` instead of silently sending an empty body. ================================================ FILE: CHANGES/12231.bugfix.rst ================================================ Adjusted pure-Python request header value validation to align with RFC 9110 control-character handling, while preserving lax response parser behavior, and added regression tests for Host/header control-character cases. -- by :user:`rodrigobnogueira`. 
================================================ FILE: CHANGES/12240.bugfix.rst ================================================ Rejected duplicate singleton headers (``Host``, ``Content-Type``, ``Content-Length``, etc.) in the C extension HTTP parser to match the pure Python parser behavior, preventing potential host-based access control bypasses via parser differentials -- by :user:`rodrigobnogueira`. ================================================ FILE: CHANGES/12249.bugfix.rst ================================================ Aligned the pure-Python HTTP request parser with the C parser by splitting comma-separated and repeated ``Connection`` header values for keep-alive, close, and upgrade handling -- by :user:`rodrigobnogueira`. ================================================ FILE: CHANGES/2174.bugfix ================================================ Raise 400 Bad Request on server-side `await request.json()` if an incorrect content-type is received. ================================================ FILE: CHANGES/2835.breaking.rst ================================================ Drop lowercased enum items of ``WSMsgType`` (text, binary, ...), use uppercased items instead (TEXT, BINARY, ...). ================================================ FILE: CHANGES/2977.breaking.rst ================================================ Drop aiodns<1.1 support. ================================================ FILE: CHANGES/3310.bugfix ================================================ Docs clarification that aiohttp client does not support HTTP Pipelining. ================================================ FILE: CHANGES/3462.feature ================================================ ``web.HTTPException`` and derived classes are not inherited from ``web.Response`` anymore.
================================================ FILE: CHANGES/3463.breaking.rst ================================================ Make ``ClientSession`` a slot-based class, convert debug-mode warning about a wild session modification into a strict error. ================================================ FILE: CHANGES/3482.bugfix ================================================ Do not return `None` on `await response.json()` when body is empty. Instead, raise `json.JSONDecodeError` as expected. ================================================ FILE: CHANGES/3538.breaking.rst ================================================ Drop ``@aiohttp.streamer`` decorator, use async generators instead. ================================================ FILE: CHANGES/3539.breaking.rst ================================================ Disallow creation of aiohttp objects (``ClientSession``, ``Connector`` etc.) without a running event loop. ================================================ FILE: CHANGES/3540.feature ================================================ Make the sanity check for web-handler return values work in release mode ================================================ FILE: CHANGES/3542.breaking.rst ================================================ Setting ``web.Application`` custom attributes is now forbidden ================================================ FILE: CHANGES/3545.feature ================================================ Drop custom router support ================================================ FILE: CHANGES/3547.breaking.rst ================================================ Remove deprecated resp.url_obj ================================================ FILE: CHANGES/3548.breaking.rst ================================================ Drop deprecated SSL client settings. ================================================ FILE: CHANGES/3559.doc ================================================ Clarified ``WebSocketResponse`` closure in the quick start example.
================================================ FILE: CHANGES/3562.bugfix ================================================ Raise ``web_exceptions.HTTPUnsupportedMediaType`` when an invalid `Content-Type` encoding is passed. ================================================ FILE: CHANGES/3569.feature ================================================ Make new style middleware default, deprecate the @middleware decorator and remove support for old-style middleware. ================================================ FILE: CHANGES/3580.breaking.rst ================================================ Drop explicit loop. Use ``asyncio.get_event_loop()`` instead if the loop instance is needed. All aiohttp objects work with the currently running loop; creating aiohttp instances (e.g. ClientSession) when the loop is not running is forbidden. As a side effect of this PR, passing callables to the ``aiohttp_server()`` and ``aiohttp_client()`` pytest fixtures is forbidden; please call these callables explicitly. ================================================ FILE: CHANGES/3612.bugfix ================================================ Fixed a grammatical error in documentation ================================================ FILE: CHANGES/3613.bugfix ================================================ Use sanitized URL as Location header in redirects ================================================ FILE: CHANGES/3642.doc ================================================ Modify documentation for Resolvers to make it clear that asynchronous resolver is not used by default when aiodns is installed. ================================================ FILE: CHANGES/3685.doc ================================================ Add documentation regarding creating and destroying persistent session. ================================================ FILE: CHANGES/3721.bugfix ================================================ Add the missing `TestClient.scheme` property.
================================================ FILE: CHANGES/3767.feature ================================================ Add ``AbstractAsyncAccessLogger`` to allow IO while logging. ================================================ FILE: CHANGES/3787.feature ================================================ Added ability to use contextvars in logger ================================================ FILE: CHANGES/3796.feature ================================================ Add a debug argument to `web.run_app()` for enabling debug mode on loop. ================================================ FILE: CHANGES/3890.breaking.rst ================================================ Drop deprecated `read_timeout` and `conn_timeout` in `ClientSession` constructor, please use `timeout` argument instead. ================================================ FILE: CHANGES/3901.breaking.rst ================================================ Drop sync context managers that raise ``TypeError`` already. ================================================ FILE: CHANGES/3929.breaking.rst ================================================ Drop processing sync web-handlers (deprecated since aiohttp 3.0) ================================================ FILE: CHANGES/3931.breaking.rst ================================================ Drop deprecated ``BaseRequest.message``, ``BaseRequest.loop``, ``BaseRequest.has_body`` ================================================ FILE: CHANGES/3932.breaking.rst ================================================ Drop deprecated ``unused_port``, ``test_server``, ``raw_test_server`` and ``test_client`` pytest fixtures.
================================================ FILE: CHANGES/3933.breaking.rst ================================================ Forbid inheritance from ``ClientSession`` and ``Application`` ================================================ FILE: CHANGES/3934.breaking.rst ================================================ Drop deprecated ``ClientResponseError.code`` attribute ================================================ FILE: CHANGES/3935.breaking.rst ================================================ Drop deprecated ``ClientSession.loop`` and ``Connection.loop``. Forbid changing ``ClientSession.requote_redirect_url``. ================================================ FILE: CHANGES/3939.breaking.rst ================================================ Drop deprecated ``Application.make_handler()`` ================================================ FILE: CHANGES/3940.breaking.rst ================================================ Drop HTTP chunk size from client and server, remove deprecated `response.output_length`. ================================================ FILE: CHANGES/3942.breaking.rst ================================================ Make `web.BaseRequest`, `web.Request`, `web.StreamResponse`, `web.Response` and `web.WebSocketResponse` slot-based, prevent custom instance attributes. ================================================ FILE: CHANGES/3948.breaking.rst ================================================ Forbid changing frozen app properties. 
================================================ FILE: CHANGES/3994.misc ================================================ correct the names of some functions in ``tests/test_client_functional.py`` ================================================ FILE: CHANGES/4161.doc ================================================ Update contributing guide so new contributors can successfully install dependencies ================================================ FILE: CHANGES/4277.feature ================================================ Added ``set_cookie`` and ``del_cookie`` methods to ``HTTPException`` ================================================ FILE: CHANGES/4283.bugfix ================================================ Fix incorrect code in example ================================================ FILE: CHANGES/4299.bugfix ================================================ Delete older code in example (:file:`examples/web_classview.py`) ================================================ FILE: CHANGES/4302.bugfix ================================================ Fixed the support of route handlers wrapped by :py:func:`functools.partial` ================================================ FILE: CHANGES/4368.bugfix ================================================ Make `web.BaseRequest`, `web.Request`, `web.StreamResponse`, `web.Response` and `web.WebSocketResponse` weak referenceable again. ================================================ FILE: CHANGES/4452.doc ================================================ Fixed a typo in the ``client_quickstart`` doc. ================================================ FILE: CHANGES/4504.doc ================================================ Updated the contribution guide to reflect the automatic thread locking policy. ================================================ FILE: CHANGES/4526.bugfix ================================================ Ignore protocol exceptions after it is closed. 
================================================ FILE: CHANGES/4558.bugfix ================================================ Fixed body_size comparison to client_max_size for web request. ================================================ FILE: CHANGES/4656.bugfix ================================================ Propagate all warnings captured in coroutine test functions to pytest. ================================================ FILE: CHANGES/4695.doc ================================================ Added documentation on how to patch unittest cases with decorator for python < 3.8 ================================================ FILE: CHANGES/4706.feature ================================================ Add a fixture ``aiohttp_client_cls`` that allows usage of ``aiohttp.test_utils.TestClient`` custom implementations in tests. ================================================ FILE: CHANGES/5075.feature ================================================ Multidict > 5 is now supported ================================================ FILE: CHANGES/5191.doc ================================================ Add pytest-aiohttp-client library to third party usage list ================================================ FILE: CHANGES/5258.bugfix ================================================ Fixed github workflow `update-pre-commit` on forks, since this workflow should run only in the main repository and also because it was giving failed jobs on all the forks. Now it will show up as skipped workflow. ================================================ FILE: CHANGES/5278.breaking.rst ================================================ Drop Python 3.6 support ================================================ FILE: CHANGES/5284.breaking.rst ================================================ ``attrs`` library was replaced with ``dataclasses``. Replace ``attr.evolve()`` with ``dataclasses.replace()`` if needed. 
================================================ FILE: CHANGES/5284.feature ================================================ Use ``dataclasses`` instead of ``attrs`` for ``ClientTimeout``, client signals, and other few internal structures. ================================================ FILE: CHANGES/5287.feature ================================================ Previously, ``sentinel`` was processed as either ``object`` or ``Any``; both variants are far from perfect. Now ``sentinel`` has a dedicated type which is not equal to anything. ================================================ FILE: CHANGES/5516.misc ================================================ Removed @unittest_run_loop. This is now the default behaviour. ================================================ FILE: CHANGES/5533.misc ================================================ Add regression test for 0 timeouts. ================================================ FILE: CHANGES/5558.bugfix ================================================ Add parsing boundary from Content-Type header while making POST request ================================================ FILE: CHANGES/5634.feature ================================================ A warning was added, when a cookie's length exceeds the :rfc:`6265` minimum client support -- :user:`anesabml`. ================================================ FILE: CHANGES/5783.feature ================================================ Started keeping the ``Authorization`` header during HTTP -> HTTPS redirects when the host remains the same. ================================================ FILE: CHANGES/5806.misc ================================================ Remove last remnants of attrs library. ================================================ FILE: CHANGES/5829.misc ================================================ Disallow untyped defs on internal tests.
================================================ FILE: CHANGES/5870.misc ================================================ Simplify generator expression. ================================================ FILE: CHANGES/5894.bugfix ================================================ Fix JSON media type suffix matching with main types other than application. ================================================ FILE: CHANGES/6180.bugfix ================================================ Fixed matching the JSON media type to not accept arbitrary characters after ``application/json`` or the ``+json`` media type suffix. ================================================ FILE: CHANGES/6181.bugfix ================================================ Make JSON media type matching case insensitive per RFC 2045. ================================================ FILE: CHANGES/6193.feature ================================================ Bump async-timeout to >=4.0 ================================================ FILE: CHANGES/6547.bugfix ================================================ Remove overlapping slots in ``RequestHandler``, fix broken slots inheritance in :py:class:`~aiohttp.web.StreamResponse`. ================================================ FILE: CHANGES/6721.misc ================================================ Remove unused argument `max_headers` of HeadersParser. ================================================ FILE: CHANGES/6979.doc ================================================ Improve grammar and brevity in communication in the Policy for Backward Incompatible Changes section of ``docs/index.rst`` -- :user:`Paarth`. ================================================ FILE: CHANGES/6998.doc ================================================ Added documentation on client authentication and updating headers. 
-- by :user:`faph` ================================================ FILE: CHANGES/7107.breaking.rst ================================================ Removed deprecated ``.loop``, ``.setUpAsync()``, ``.tearDownAsync()`` and ``.get_app()`` from ``AioHTTPTestCase``. ================================================ FILE: CHANGES/7265.breaking.rst ================================================ Deleted ``size`` arg from ``StreamReader.feed_data`` -- by :user:`DavidRomanovizc`. ================================================ FILE: CHANGES/7319.feature.rst ================================================ Changed ``WSMessage`` to a tagged union of ``NamedTuple`` -- by :user:`Dreamsorcerer`. This change allows type checkers to know the precise type of ``data`` after checking the ``type`` attribute. If accessing messages by tuple indexes, the order has now changed. Code such as: ``typ, data, extra = ws_message`` will need to be changed to: ``data, extra, typ = ws_message`` No changes are needed if accessing by attribute name. ================================================ FILE: CHANGES/7677.bugfix ================================================ Changed ``AppKey`` warning to ``web.NotAppKeyWarning`` and stop it being displayed by default. -- by :user:`Dreamsorcerer` ================================================ FILE: CHANGES/7772.bugfix ================================================ Fix CONNECT always being treated as having an empty body ================================================ FILE: CHANGES/7815.bugfix ================================================ Fixed an issue where the client could go into an infinite loop. 
-- by :user:`Dreamsorcerer` ================================================ FILE: CHANGES/8048.breaking.rst ================================================ Removed deprecated support for `ssl=None` -- by :user:`Dreamsorcerer` ================================================ FILE: CHANGES/8139.contrib.rst ================================================ Two definitions for "test_invalid_route_name" existed, only one was being run. Refactored them into a single parameterized test. Enabled lint rule to prevent regression. -- by :user:`alexmac`. ================================================ FILE: CHANGES/8197.doc ================================================ Fixed false behavior of base_url param for ClientSession in client documentation -- by :user:`alexis974`. ================================================ FILE: CHANGES/8303.breaking.rst ================================================ Removed ``content_transfer_encoding`` parameter in :py:meth:`FormData.add_field() <aiohttp.FormData.add_field>` and passing bytes no longer creates a file field unless the ``filename`` parameter is used -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/8596.breaking.rst ================================================ Removed old async compatibility from ``ClientResponse.release()`` -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/8698.breaking.rst ================================================ Changed signature of ``content_disposition_header()`` so ``params`` is now passed as a dict, in order to reduce typing errors -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/8957.breaking.rst ================================================ Removed ``version`` parameter from ``.set_cookie()`` (this shouldn't exist in cookies today) -- by :user:`Dreamsorcerer`.
================================================ FILE: CHANGES/9109.breaking.rst ================================================ Changed default value to ``compress`` from ``None`` to ``False`` (``None`` is no longer an expected value) -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/9212.packaging.rst ================================================ Removed remaining `make_mocked_coro` in the test suite -- by :user:`polkapolka`. ================================================ FILE: CHANGES/9254.breaking.rst ================================================ Stopped allowing use of ``ClientResponse.text()``/``ClientResponse.json()`` after leaving ``async with`` context. This now matches the behaviour of ``ClientResponse.read()`` -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/9292.breaking.rst ================================================ Started rejecting non string values in `FormData`, to avoid unexpected results -- by :user:`Dreamsorcerer`. ================================================ FILE: CHANGES/9413.misc.rst ================================================ Reduced memory required for many small objects by adding ``__slots__`` to dataclasses -- by :user:`bdraco`. ================================================ FILE: CHANGES/README.rst ================================================ .. _Adding change notes with your PRs: Adding change notes with your PRs ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ It is very important to maintain a log for news of how updating to the new version of the software will affect end-users. This is why we enforce collection of the change fragment files in pull requests as per `Towncrier philosophy`_. The idea is that when somebody makes a change, they must record the bits that would affect end-users, only including information that would be useful to them.
Then, when the maintainers publish a new release, they'll automatically use these records to compose a change log for the respective version. It is important to understand that including unnecessary low-level implementation related details generates noise that is not particularly useful to the end-users most of the time. And so such details should be recorded in the Git history rather than a changelog. Alright! So how to add a news fragment? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ``aiohttp`` uses `towncrier <https://pypi.org/project/towncrier>`_ for changelog management. To submit a change note about your PR, add a text file into the ``CHANGES/`` folder. It should contain an explanation of what applying this PR will change in the way end-users interact with the project. One sentence is usually enough but feel free to add as many details as you feel necessary for the users to understand what it means. **Use the past tense** for the text in your fragment because, combined with others, it will be a part of the "news digest" telling the readers **what changed** in a specific version of the library *since the previous version*. You should also use *reStructuredText* syntax for highlighting code (inline or block), linking parts of the docs or external sites. However, you do not need to reference the issue or PR numbers here as *towncrier* will automatically add a reference to all of the affected issues when rendering the news file. If you wish to sign your change, feel free to add ``-- by :user:`github-username``` at the end (replace ``github-username`` with your own!). Finally, name your file following the convention that Towncrier understands: it should start with the number of an issue or a PR followed by a dot, then add a patch type, like ``feature``, ``doc``, ``contrib`` etc., and add ``.rst`` as a suffix. If you need to add more than one fragment, you may add an optional sequence number (delimited with another period) between the type and the suffix.
In general the name will follow ``<pr_number>.<category>.rst`` pattern, where the categories are: - ``bugfix``: A bug fix for something we deemed an improper undesired behavior that got corrected in the release to match pre-agreed expectations. - ``feature``: A new behavior, public APIs. That sort of stuff. - ``deprecation``: A declaration of future API removals and breaking changes in behavior. - ``breaking``: When something public gets removed in a breaking way. Could be deprecated in an earlier release. - ``doc``: Notable updates to the documentation structure or build process. - ``packaging``: Notes for downstreams about unobvious side effects and tooling. Changes in the test invocation considerations and runtime assumptions. - ``contrib``: Stuff that affects the contributor experience. e.g. Running tests, building the docs, setting up the development environment. - ``misc``: Changes that are hard to assign to any of the above categories. A pull request may have more than one of these components, for example a code change may introduce a new feature that deprecates an old feature, in which case two fragments should be added. It is not necessary to make a separate documentation fragment for documentation changes accompanying the relevant code changes. Examples for adding changelog entries to your Pull Requests ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File :file:`CHANGES/6045.doc.1.rst`: .. code-block:: rst Added a ``:user:`` role to Sphinx config -- by :user:`webknjaz`. File :file:`CHANGES/8074.bugfix.rst`: .. code-block:: rst Fixed an unhandled exception in the Python HTTP parser on header lines starting with a colon -- by :user:`pajod`. Invalid request lines with anything but a dot between the HTTP major and minor version are now rejected. Invalid header field names containing question mark or slash are now rejected. Such requests are incompatible with :rfc:`9110#section-5.6.2` and are not known to be of any legitimate use.
File :file:`CHANGES/4594.feature.rst`: .. code-block:: rst Added support for ``ETag`` to :py:class:`~aiohttp.web.FileResponse` -- by :user:`greshilov`, :user:`serhiy-storchaka` and :user:`asvetlov`. .. tip:: See :file:`pyproject.toml` for all available categories (``tool.towncrier.type``). .. _Towncrier philosophy: https://towncrier.readthedocs.io/en/stable/#philosophy ================================================ FILE: CHANGES.rst ================================================ .. You should *NOT* be adding new change log entries to this file, this file is managed by towncrier. You *may* edit previous change logs to fix problems like typo corrections or such. To add a new change log entry, please see https://pip.pypa.io/en/latest/development/#adding-a-news-entry we named the news folder "changes". WARNING: Don't drop the next directive! .. towncrier release notes start 3.13.3 (2026-01-03) =================== This release contains fixes for several vulnerabilities. It is advised to upgrade as soon as possible. Bug fixes --------- - Fixed proxy authorization headers not being passed when reusing a connection, which caused 407 (Proxy authentication required) errors -- by :user:`GLeurquin`. *Related issues and pull requests on GitHub:* :issue:`2596`. - Fixed multipart reading failing when encountering an empty body part -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`11857`. - Fixed a case where the parser wasn't raising an exception for a websocket continuation frame when there was no initial frame in context. *Related issues and pull requests on GitHub:* :issue:`11862`. Removals and backward incompatible breaking changes --------------------------------------------------- - ``Brotli`` and ``brotlicffi`` minimum version is now 1.2. Decompression now has a default maximum output size of 32MiB per decompress call -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`11898`. 
Packaging updates and notes for downstreams ------------------------------------------- - Moved dependency metadata from :file:`setup.cfg` to :file:`pyproject.toml` per :pep:`621` -- by :user:`cdce8p`. *Related issues and pull requests on GitHub:* :issue:`11643`. Contributor-facing changes -------------------------- - Removed unused ``update-pre-commit`` github action workflow -- by :user:`Cycloctane`. *Related issues and pull requests on GitHub:* :issue:`11689`. Miscellaneous internal changes ------------------------------ - Optimized web server performance when access logging is disabled by reducing time syscalls -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10713`. - Added regression test for cached logging status -- by :user:`meehand`. *Related issues and pull requests on GitHub:* :issue:`11778`. ---- 3.13.2 (2025-10-28) =================== Bug fixes --------- - Fixed cookie parser to continue parsing subsequent cookies when encountering a malformed cookie that fails regex validation, such as Google's ``g_state`` cookie with unescaped quotes -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11632`. - Fixed loading netrc credentials from the default :file:`~/.netrc` (:file:`~/_netrc` on Windows) location when the :envvar:`NETRC` environment variable is not set -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11713`, :issue:`11714`. - Fixed WebSocket compressed sends to be cancellation safe. Tasks are now shielded during compression to prevent compressor state corruption. This ensures that the stateful compressor remains consistent even when send operations are cancelled -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11725`. ---- 3.13.1 (2025-10-17) =================== Features -------- - Make configuration options in ``AppRunner`` also available in ``run_app()`` -- by :user:`Cycloctane`. *Related issues and pull requests on GitHub:* :issue:`11633`. 
Bug fixes --------- - Switched to `backports.zstd` for Python <3.14 and fixed zstd decompression for chunked zstd streams -- by :user:`ZhaoMJ`. Note: Users who installed ``zstandard`` for support on Python <3.14 will now need to install ``backports.zstd`` instead (installing ``aiohttp[speedups]`` will do this automatically). *Related issues and pull requests on GitHub:* :issue:`11623`. - Updated ``Content-Type`` header parsing to return ``application/octet-stream`` when header contains invalid syntax. See :rfc:`9110#section-8.3-5`. -- by :user:`sgaist`. *Related issues and pull requests on GitHub:* :issue:`10889`. - Fixed Python 3.14 support when built without ``zstd`` support -- by :user:`JacobHenner`. *Related issues and pull requests on GitHub:* :issue:`11603`. - Fixed blocking I/O in the event loop when using netrc authentication by moving netrc file lookup to an executor -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11634`. - Fixed routing to a sub-application added via ``.add_domain()`` not working if the same path exists on the parent app. -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`11673`. Packaging updates and notes for downstreams ------------------------------------------- - Moved core packaging metadata from :file:`setup.cfg` to :file:`pyproject.toml` per :pep:`621` -- by :user:`cdce8p`. *Related issues and pull requests on GitHub:* :issue:`9951`. ---- 3.13.0 (2025-10-06) =================== Features -------- - Added support for Python 3.14. *Related issues and pull requests on GitHub:* :issue:`10851`, :issue:`10872`. - Added support for free-threading in Python 3.14+ -- by :user:`kumaraditya303`. *Related issues and pull requests on GitHub:* :issue:`11466`, :issue:`11464`. - Added support for Zstandard (aka Zstd) compression -- by :user:`KGuillaume-chaps`. *Related issues and pull requests on GitHub:* :issue:`11161`. 
- Added ``StreamReader.total_raw_bytes`` to check the number of bytes downloaded -- by :user:`robpats`. *Related issues and pull requests on GitHub:* :issue:`11483`. Bug fixes --------- - Fixed pytest plugin to not use deprecated :py:mod:`asyncio` policy APIs. *Related issues and pull requests on GitHub:* :issue:`10851`. - Updated `Content-Disposition` header parsing to handle trailing semicolons and empty parts -- by :user:`PLPeeters`. *Related issues and pull requests on GitHub:* :issue:`11243`. - Fixed saved ``CookieJar`` failing to be loaded if cookies have ``partitioned`` flag when ``http.cookies`` does not have partitioned cookies support. -- by :user:`Cycloctane`. *Related issues and pull requests on GitHub:* :issue:`11523`. Improved documentation ---------------------- - Added ``Wireup`` to third-party libraries -- by :user:`maldoinc`. *Related issues and pull requests on GitHub:* :issue:`11233`. Packaging updates and notes for downstreams ------------------------------------------- - The `blockbuster` test dependency is now optional; the corresponding test fixture is disabled when it is unavailable -- by :user:`musicinybrain`. *Related issues and pull requests on GitHub:* :issue:`11363`. - Added ``riscv64`` build to releases -- by :user:`eshattow`. *Related issues and pull requests on GitHub:* :issue:`11425`. Contributor-facing changes -------------------------- - Fixed ``test_send_compress_text`` failing when alternative zlib implementation is used. (``zlib-ng`` in python 3.14 windows build) -- by :user:`Cycloctane`. *Related issues and pull requests on GitHub:* :issue:`11546`. ---- 3.12.15 (2025-07-28) ==================== Bug fixes --------- - Fixed :class:`~aiohttp.DigestAuthMiddleware` to preserve the algorithm case from the server's challenge in the authorization response.
This improves compatibility with servers that perform case-sensitive algorithm matching (e.g., servers expecting ``algorithm=MD5-sess`` instead of ``algorithm=MD5-SESS``) -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11352`. Improved documentation ---------------------- - Remove outdated contents of ``aiohttp-devtools`` and ``aiohttp-swagger`` from Web_advanced docs. -- by :user:`Cycloctane` *Related issues and pull requests on GitHub:* :issue:`11347`. Packaging updates and notes for downstreams ------------------------------------------- - Started including the ``llhttp`` :file:`LICENSE` file in wheels by adding ``vendor/llhttp/LICENSE`` to ``license-files`` in :file:`setup.cfg` -- by :user:`threexc`. *Related issues and pull requests on GitHub:* :issue:`11226`. Contributor-facing changes -------------------------- - Updated a regex in `test_aiohttp_request_coroutine` for Python 3.14. *Related issues and pull requests on GitHub:* :issue:`11271`. ---- 3.12.14 (2025-07-10) ==================== Bug fixes --------- - Fixed file uploads failing with HTTP 422 errors when encountering 307/308 redirects, and 301/302 redirects for non-POST methods, by preserving the request body when appropriate per :rfc:`9110#section-15.4.3-3.1` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11270`. - Fixed :py:meth:`ClientSession.close() <aiohttp.ClientSession.close>` hanging indefinitely when using HTTPS requests through HTTP proxies -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11273`. - Bumped minimum version of aiosignal to 1.4+ to resolve typing issues -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`11280`. Features -------- - Added initial trailer parsing logic to Python HTTP parser -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`11269`. Improved documentation ---------------------- - Clarified exceptions raised by ``WebSocketResponse.send_frame`` et al.
-- by :user:`DoctorJohn`. *Related issues and pull requests on GitHub:* :issue:`11234`. ---- 3.12.13 (2025-06-14) ==================== Bug fixes --------- - Fixed auto-created :py:class:`~aiohttp.TCPConnector` not using the session's event loop when :py:class:`~aiohttp.ClientSession` is created without an explicit connector -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11147`. ---- 3.12.12 (2025-06-09) ==================== Bug fixes --------- - Fixed cookie unquoting to properly handle octal escape sequences in cookie values (e.g., ``\012`` for newline) by vendoring the correct ``_unquote`` implementation from Python's ``http.cookies`` module -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11173`. - Fixed ``Cookie`` header parsing to treat attribute names as regular cookies per :rfc:`6265#section-5.4` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11178`. ---- 3.12.11 (2025-06-07) ==================== Features -------- - Improved SSL connection handling by changing the default ``ssl_shutdown_timeout`` from ``0.1`` to ``0`` seconds. SSL connections now use Python's default graceful shutdown during normal operation but are aborted immediately when the connector is closed, providing optimal behavior for both cases. Also added support for ``ssl_shutdown_timeout=0`` on all Python versions. Previously, this value was rejected on Python 3.11+ and ignored on earlier versions. Non-zero values on Python < 3.11 now trigger a ``RuntimeWarning`` -- by :user:`bdraco`. The ``ssl_shutdown_timeout`` parameter is now deprecated and will be removed in aiohttp 4.0 as there is no clear use case for changing the default. *Related issues and pull requests on GitHub:* :issue:`11148`. Deprecations (removal in next major release) -------------------------------------------- - Improved SSL connection handling by changing the default ``ssl_shutdown_timeout`` from ``0.1`` to ``0`` seconds. 
SSL connections now use Python's default graceful shutdown during normal operation but are aborted immediately when the connector is closed, providing optimal behavior for both cases. Also added support for ``ssl_shutdown_timeout=0`` on all Python versions. Previously, this value was rejected on Python 3.11+ and ignored on earlier versions. Non-zero values on Python < 3.11 now trigger a ``RuntimeWarning`` -- by :user:`bdraco`. The ``ssl_shutdown_timeout`` parameter is now deprecated and will be removed in aiohttp 4.0 as there is no clear use case for changing the default. *Related issues and pull requests on GitHub:* :issue:`11148`. ---- 3.12.10 (2025-06-07) ==================== Bug fixes --------- - Fixed leak of ``aiodns.DNSResolver`` when :py:class:`~aiohttp.TCPConnector` is closed and no resolver was passed when creating the connector -- by :user:`Tasssadar`. This was a regression introduced in version 3.12.0 (:pr:`10897`). *Related issues and pull requests on GitHub:* :issue:`11150`. ---- 3.12.9 (2025-06-04) =================== Bug fixes --------- - Fixed ``IOBasePayload`` and ``TextIOPayload`` reading entire files into memory when streaming large files -- by :user:`bdraco`. When using file-like objects with the aiohttp client, the entire file would be read into memory if the file size was provided in the ``Content-Length`` header. This could cause out-of-memory errors when uploading large files. The payload classes now correctly read data in chunks of ``READ_SIZE`` (64KB) regardless of the total content length. *Related issues and pull requests on GitHub:* :issue:`11138`. ---- 3.12.8 (2025-06-04) =================== Features -------- - Added preemptive digest authentication to :class:`~aiohttp.DigestAuthMiddleware` -- by :user:`bdraco`. The middleware now reuses authentication credentials for subsequent requests to the same protection space, improving efficiency by avoiding extra authentication round trips. 
This behavior matches how web browsers handle digest authentication and follows :rfc:`7616#section-3.6`. Preemptive authentication is enabled by default but can be disabled by passing ``preemptive=False`` to the middleware constructor. *Related issues and pull requests on GitHub:* :issue:`11128`, :issue:`11129`. ---- 3.12.7 (2025-06-02) =================== .. warning:: This release fixes an issue where the ``quote_cookie`` parameter was not being properly respected for shared cookies (domain="", path=""). If your server does not handle quoted cookies correctly, you may need to disable cookie quoting by setting ``quote_cookie=False`` when creating your :class:`~aiohttp.ClientSession` or :class:`~aiohttp.CookieJar`. See :ref:`aiohttp-client-cookie-quoting-routine` for details. Bug fixes --------- - Fixed cookie parsing to be more lenient when handling cookies with special characters in names or values. Cookies with characters like ``{``, ``}``, and ``/`` in names are now accepted instead of causing a :exc:`~http.cookies.CookieError` and 500 errors. Additionally, cookies with mismatched quotes in values are now parsed correctly, and quoted cookie values are now handled consistently whether or not they include special attributes like ``Domain``. Also fixed :class:`~aiohttp.CookieJar` to ensure shared cookies (domain="", path="") respect the ``quote_cookie`` parameter, making cookie quoting behavior consistent for all cookies -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`2683`, :issue:`5397`, :issue:`7993`, :issue:`11112`. - Fixed an issue where cookies with duplicate names but different domains or paths were lost when updating the cookie jar. The :class:`~aiohttp.ClientSession` cookie jar now correctly stores all cookies even if they have the same name but different domain or path, following the :rfc:`6265#section-5.3` storage model -- by :user:`bdraco`. 
Note that :attr:`ClientResponse.cookies <aiohttp.ClientResponse.cookies>` returns a :class:`~http.cookies.SimpleCookie` which uses the cookie name as a key, so only the last cookie with each name is accessible via this interface. All cookies can be accessed via :meth:`ClientResponse.headers.getall('Set-Cookie') <multidict.MultiDictProxy.getall>` if needed. *Related issues and pull requests on GitHub:* :issue:`4486`, :issue:`11105`, :issue:`11106`. Miscellaneous internal changes ------------------------------ - Avoided creating closed futures in ``ResponseHandler`` that will never be awaited -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11107`. - Downgraded the logging level for connector close errors from ERROR to DEBUG, as these are expected behavior with TLS 1.3 connections -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11114`. ---- 3.12.6 (2025-05-31) =================== Bug fixes --------- - Fixed spurious "Future exception was never retrieved" warnings for connection lost errors when the connector is not closed -- by :user:`bdraco`. When connections are lost, the exception is now marked as retrieved since it is always propagated through other means, preventing unnecessary warnings in logs. *Related issues and pull requests on GitHub:* :issue:`11100`. ---- 3.12.5 (2025-05-30) =================== Features -------- - Added ``ssl_shutdown_timeout`` parameter to :py:class:`~aiohttp.ClientSession` and :py:class:`~aiohttp.TCPConnector` to control the grace period for SSL shutdown handshake on TLS connections. This helps prevent "connection reset" errors on the server side while avoiding excessive delays during connector cleanup. Note: This parameter only takes effect on Python 3.11+ -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`11091`, :issue:`11094`. Miscellaneous internal changes ------------------------------ - Improved performance of isinstance checks by using collections.abc types instead of typing module equivalents -- by :user:`bdraco`.
*Related issues and pull requests on GitHub:* :issue:`11085`, :issue:`11088`. ---- 3.12.4 (2025-05-28) =================== Bug fixes --------- - Fixed connector not waiting for connections to close before returning from :meth:`~aiohttp.BaseConnector.close` (partial backport of :pr:`3733`) -- by :user:`atemate` and :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`1925`, :issue:`11074`. ---- 3.12.3 (2025-05-28) =================== Bug fixes --------- - Fixed memory leak in :py:meth:`~aiohttp.CookieJar.filter_cookies` that caused unbounded memory growth when making requests to different URL paths -- by :user:`bdraco` and :user:`Cycloctane`. *Related issues and pull requests on GitHub:* :issue:`11052`, :issue:`11054`. ---- 3.12.2 (2025-05-26) =================== Bug fixes --------- - Fixed ``Content-Length`` header not being set to ``0`` for non-GET requests with ``None`` body -- by :user:`bdraco`. Non-GET requests (``POST``, ``PUT``, ``PATCH``, ``DELETE``) with ``None`` as the body now correctly set the ``Content-Length`` header to ``0``, matching the behavior of requests with empty bytes (``b""``). This regression was introduced in aiohttp 3.12.1. *Related issues and pull requests on GitHub:* :issue:`11035`. ---- 3.12.1 (2025-05-26) =================== Features -------- - Added support for reusable request bodies to enable retries, redirects, and digest authentication -- by :user:`bdraco` and :user:`GLGDLY`. Most payloads can now be safely reused multiple times, fixing long-standing issues where POST requests with form data or file uploads would fail on redirects with errors like "Form data has been processed already" or "I/O operation on closed file". This also enables digest authentication to work with request bodies and allows retry mechanisms to resend requests without consuming the payload. Note that payloads derived from async iterables may still not be reusable in some cases. 
*Related issues and pull requests on GitHub:* :issue:`5530`, :issue:`5577`, :issue:`9201`, :issue:`11017`. ---- 3.12.0 (2025-05-24) =================== Bug fixes --------- - Fixed :py:attr:`~aiohttp.web.WebSocketResponse.prepared` property to correctly reflect the prepared state, especially during timeout scenarios -- by :user:`bdraco` *Related issues and pull requests on GitHub:* :issue:`6009`, :issue:`10988`. - Response is now always True, instead of using MutableMapping behaviour (False when map is empty) *Related issues and pull requests on GitHub:* :issue:`10119`. - Fixed connection reuse for file-like data payloads by ensuring buffer truncation respects content-length boundaries and preventing premature connection closure race -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10325`, :issue:`10915`, :issue:`10941`, :issue:`10943`. - Fixed pytest plugin to not use deprecated :py:mod:`asyncio` policy APIs. *Related issues and pull requests on GitHub:* :issue:`10851`. - Fixed :py:class:`~aiohttp.resolver.AsyncResolver` not using the ``loop`` argument in versions 3.x where it should still be supported -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10951`. Features -------- - Added a comprehensive HTTP Digest Authentication client middleware (DigestAuthMiddleware) that implements RFC 7616. The middleware supports all standard hash algorithms (MD5, SHA, SHA-256, SHA-512) with session variants, handles both 'auth' and 'auth-int' quality of protection options, and automatically manages the authentication flow by intercepting 401 responses and retrying with proper credentials -- by :user:`feus4177`, :user:`TimMenninger`, and :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`2213`, :issue:`10725`. - Added client middleware support -- by :user:`bdraco` and :user:`Dreamsorcerer`. 
This change allows users to add middleware to the client session and requests, enabling features like authentication, logging, and request/response modification without modifying the core request logic. Additionally, the ``session`` attribute was added to ``ClientRequest``, allowing middleware to access the session for making additional requests. *Related issues and pull requests on GitHub:* :issue:`9732`, :issue:`10902`, :issue:`10945`, :issue:`10952`, :issue:`10959`, :issue:`10968`. - Allow user setting zlib compression backend -- by :user:`TimMenninger` This change allows the user to call :func:`aiohttp.set_zlib_backend()` with the zlib compression module of their choice. Default behavior continues to use the builtin ``zlib`` library. *Related issues and pull requests on GitHub:* :issue:`9798`. - Added support for overriding the base URL with an absolute one in client sessions -- by :user:`vivodi`. *Related issues and pull requests on GitHub:* :issue:`10074`. - Added ``host`` parameter to ``aiohttp_server`` fixture -- by :user:`christianwbrock`. *Related issues and pull requests on GitHub:* :issue:`10120`. - Detect blocking calls in coroutines using BlockBuster -- by :user:`cbornet`. *Related issues and pull requests on GitHub:* :issue:`10433`. - Added ``socket_factory`` to :py:class:`aiohttp.TCPConnector` to allow specifying custom socket options -- by :user:`TimMenninger`. *Related issues and pull requests on GitHub:* :issue:`10474`, :issue:`10520`, :issue:`10961`, :issue:`10962`. - Started building armv7l manylinux wheels -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10797`. - Implemented shared DNS resolver management to fix excessive resolver object creation when using multiple client sessions. 
The new ``_DNSResolverManager`` singleton ensures only one ``DNSResolver`` object is created for default configurations, significantly reducing resource usage and improving performance for applications using multiple client sessions simultaneously -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10847`, :issue:`10923`, :issue:`10946`. - Upgraded to LLHTTP 9.3.0 -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`10972`. - Optimized small HTTP requests/responses by coalescing headers and body into a single TCP packet -- by :user:`bdraco`. This change enhances network efficiency by reducing the number of packets sent for small HTTP payloads, improving latency and reducing overhead. Most importantly, this fixes compatibility with memory-constrained IoT devices that can only perform a single read operation and expect HTTP requests in one packet. The optimization uses zero-copy ``writelines`` when coalescing data and works with both regular and chunked transfer encoding. When ``aiohttp`` uses client middleware to communicate with an ``aiohttp`` server, connection reuse is more likely to occur since complete responses arrive in a single packet for small payloads. This aligns ``aiohttp`` with other popular HTTP clients that already coalesce small requests. *Related issues and pull requests on GitHub:* :issue:`10991`. Improved documentation ---------------------- - Improved documentation for middleware by adding warnings and examples about request body stream consumption. The documentation now clearly explains that request body streams can only be read once and provides best practices for sharing parsed request data between middleware and handlers -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`2914`. Packaging updates and notes for downstreams ------------------------------------------- - Removed non SPDX-license description from ``setup.cfg`` -- by :user:`devanshu-ziphq`. 
*Related issues and pull requests on GitHub:* :issue:`10662`. - Added support for building against system ``llhttp`` library -- by :user:`mgorny`. This change adds support for :envvar:`AIOHTTP_USE_SYSTEM_DEPS` environment variable that can be used to build aiohttp against the system install of the ``llhttp`` library rather than the vendored one. *Related issues and pull requests on GitHub:* :issue:`10759`. - ``aiodns`` is now installed on Windows with speedups extra -- by :user:`bdraco`. As of ``aiodns`` 3.3.0, ``SelectorEventLoop`` is no longer required when using ``pycares`` 4.7.0 or later. *Related issues and pull requests on GitHub:* :issue:`10823`. - Fixed compatibility issue with Cython 3.1.1 -- by :user:`bdraco` *Related issues and pull requests on GitHub:* :issue:`10877`. Contributor-facing changes -------------------------- - Sped up tests by disabling ``blockbuster`` fixture for ``test_static_file_huge`` and ``test_static_file_huge_cancel`` tests -- by :user:`dikos1337`. *Related issues and pull requests on GitHub:* :issue:`9705`, :issue:`10761`. - Updated tests to avoid using deprecated :py:mod:`asyncio` policy APIs and make it compatible with Python 3.14. *Related issues and pull requests on GitHub:* :issue:`10851`. - Added Winloop to test suite to support in the future -- by :user:`Vizonex`. *Related issues and pull requests on GitHub:* :issue:`10922`. Miscellaneous internal changes ------------------------------ - Added support for the ``partitioned`` attribute in the ``set_cookie`` method. *Related issues and pull requests on GitHub:* :issue:`9870`. - Setting :attr:`aiohttp.web.StreamResponse.last_modified` to an unsupported type will now raise :exc:`TypeError` instead of silently failing -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10146`. 
---- 3.11.18 (2025-04-20) ==================== Bug fixes --------- - Disabled TLS in TLS warning (when using HTTPS proxies) for uvloop and newer Python versions -- by :user:`lezgomatt`. *Related issues and pull requests on GitHub:* :issue:`7686`. - Fixed reading fragmented WebSocket messages when the payload was masked -- by :user:`bdraco`. The problem first appeared in 3.11.17 *Related issues and pull requests on GitHub:* :issue:`10764`. ---- 3.11.17 (2025-04-19) ==================== Miscellaneous internal changes ------------------------------ - Optimized web server performance when access logging is disabled by reducing time syscalls -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10713`. - Improved web server performance when connection can be reused -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10714`. - Improved performance of the WebSocket reader -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10740`. - Improved performance of the WebSocket reader with large messages -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10744`. ---- 3.11.16 (2025-04-01) ==================== Bug fixes --------- - Replaced deprecated ``asyncio.iscoroutinefunction`` with its counterpart from ``inspect`` -- by :user:`layday`. *Related issues and pull requests on GitHub:* :issue:`10634`. - Fixed :class:`multidict.CIMultiDict` being mutated when passed to :class:`aiohttp.web.Response` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10672`. ---- 3.11.15 (2025-03-31) ==================== Bug fixes --------- - Reverted explicitly closing sockets if an exception is raised during ``create_connection`` -- by :user:`bdraco`. This change originally appeared in aiohttp 3.11.13 *Related issues and pull requests on GitHub:* :issue:`10464`, :issue:`10617`, :issue:`10656`. 
Miscellaneous internal changes ------------------------------ - Improved performance of WebSocket buffer handling -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10601`. - Improved performance of serializing headers -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10625`. ---- 3.11.14 (2025-03-16) ==================== Bug fixes --------- - Fixed an issue where dns queries were delayed indefinitely when an exception occurred in a ``trace.send_dns_cache_miss`` -- by :user:`logioniz`. *Related issues and pull requests on GitHub:* :issue:`10529`. - Fixed DNS resolution on platforms that don't support ``socket.AI_ADDRCONFIG`` -- by :user:`maxbachmann`. *Related issues and pull requests on GitHub:* :issue:`10542`. - The connector now raises :exc:`aiohttp.ClientConnectionError` instead of :exc:`OSError` when failing to explicitly close the socket after :py:meth:`asyncio.loop.create_connection` fails -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10551`. - Break cyclic references at connection close when there was a traceback -- by :user:`bdraco`. Special thanks to :user:`availov` for reporting the issue. *Related issues and pull requests on GitHub:* :issue:`10556`. - Break cyclic references when there is an exception handling a request -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10569`. Features -------- - Improved logging on non-overlapping WebSocket client protocols to include the remote address -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10564`. Miscellaneous internal changes ------------------------------ - Improved performance of parsing content types by adding a cache in the same manner currently done with mime types -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10552`. 
---- 3.11.13 (2025-02-24) ==================== Bug fixes --------- - Removed a break statement inside the finally block in :py:class:`~aiohttp.web.RequestHandler` -- by :user:`Cycloctane`. *Related issues and pull requests on GitHub:* :issue:`10434`. - Changed connection creation to explicitly close sockets if an exception is raised in the event loop's ``create_connection`` method -- by :user:`top-oai`. *Related issues and pull requests on GitHub:* :issue:`10464`. Packaging updates and notes for downstreams ------------------------------------------- - Fixed test ``test_write_large_payload_deflate_compression_data_in_eof_writelines`` failing with Python 3.12.9+ or 3.13.2+ -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10423`. Miscellaneous internal changes ------------------------------ - Added human-readable error messages to the exceptions for WebSocket disconnects due to PONG not being received -- by :user:`bdraco`. Previously, the error messages were empty strings, which made it hard to determine what went wrong. *Related issues and pull requests on GitHub:* :issue:`10422`. ---- 3.11.12 (2025-02-05) ==================== Bug fixes --------- - ``MultipartForm.decode()`` now follows RFC1341 7.2.1 with a ``CRLF`` after the boundary -- by :user:`imnotjames`. *Related issues and pull requests on GitHub:* :issue:`10270`. - Restored the missing ``total_bytes`` attribute to ``EmptyStreamReader`` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10387`. Features -------- - Updated :py:func:`~aiohttp.request` to make it accept ``_RequestOptions`` kwargs. -- by :user:`Cycloctane`. *Related issues and pull requests on GitHub:* :issue:`10300`. - Improved logging of HTTP protocol errors to include the remote address -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10332`. 
Improved documentation ---------------------- - Added ``aiohttp-openmetrics`` to list of third-party libraries -- by :user:`jelmer`. *Related issues and pull requests on GitHub:* :issue:`10304`. Packaging updates and notes for downstreams ------------------------------------------- - Added missing files to the source distribution to fix ``Makefile`` targets. Added a ``cythonize-nodeps`` target to run Cython without invoking pip to install dependencies. *Related issues and pull requests on GitHub:* :issue:`10366`. - Started building armv7l musllinux wheels -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10404`. Contributor-facing changes -------------------------- - The CI/CD workflow has been updated to use `upload-artifact` v4 and `download-artifact` v4 GitHub Actions -- by :user:`silamon`. *Related issues and pull requests on GitHub:* :issue:`10281`. Miscellaneous internal changes ------------------------------ - Restored support for zero copy writes when using Python 3.12 versions 3.12.9 and later or Python 3.13.2+ -- by :user:`bdraco`. Zero copy writes were previously disabled due to :cve:`2024-12254` which is resolved in these Python versions. *Related issues and pull requests on GitHub:* :issue:`10137`. ---- 3.11.11 (2024-12-18) ==================== Bug fixes --------- - Updated :py:meth:`~aiohttp.ClientSession.request` to reuse the ``quote_cookie`` setting from ``ClientSession._cookie_jar`` when processing cookies parameter. -- by :user:`Cycloctane`. *Related issues and pull requests on GitHub:* :issue:`10093`. - Fixed type of ``SSLContext`` for some static type checkers (e.g. pyright). *Related issues and pull requests on GitHub:* :issue:`10099`. - Updated :meth:`aiohttp.web.StreamResponse.write` annotation to also allow :class:`bytearray` and :class:`memoryview` as inputs -- by :user:`cdce8p`. *Related issues and pull requests on GitHub:* :issue:`10154`. 
- Fixed a hang where a connection previously used for a streaming download could be returned to the pool in a paused state. -- by :user:`javitonino`. *Related issues and pull requests on GitHub:* :issue:`10169`. Features -------- - Enabled ALPN on default SSL contexts. This improves compatibility with some proxies which don't work without this extension. -- by :user:`Cycloctane`. *Related issues and pull requests on GitHub:* :issue:`10156`. Miscellaneous internal changes ------------------------------ - Fixed an infinite loop that can occur when using aiohttp in combination with `async-solipsism`_ -- by :user:`bmerry`. .. _async-solipsism: https://github.com/bmerry/async-solipsism *Related issues and pull requests on GitHub:* :issue:`10149`. ---- 3.11.10 (2024-12-05) ==================== Bug fixes --------- - Fixed race condition in :class:`aiohttp.web.FileResponse` that could have resulted in an incorrect response if the file was replaced on the file system during ``prepare`` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10101`, :issue:`10113`. - Replaced deprecated call to :func:`mimetypes.guess_type` with :func:`mimetypes.guess_file_type` when using Python 3.13+ -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10102`. - Disabled zero copy writes in the ``StreamWriter`` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10125`. ---- 3.11.9 (2024-12-01) =================== Bug fixes --------- - Fixed invalid method logging unexpectedly being logged at exception level on subsequent connections -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10055`, :issue:`10076`. Miscellaneous internal changes ------------------------------ - Improved performance of parsing headers when using the C parser -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10073`.
---- 3.11.8 (2024-11-27) =================== Miscellaneous internal changes ------------------------------ - Improved performance of creating :class:`aiohttp.ClientResponse` objects when there are no cookies -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10029`. - Improved performance of creating :class:`aiohttp.ClientResponse` objects -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10030`. - Improved performance of creating objects during the HTTP request lifecycle -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10037`. - Improved performance of constructing :class:`aiohttp.web.Response` with headers -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10043`. - Improved performance of making requests when there are no auto headers to skip -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10049`. - Downgraded logging of invalid HTTP method exceptions on the first request to debug level -- by :user:`bdraco`. HTTP requests starting with an invalid method are relatively common, especially when connected to the public internet, because browsers or other clients may try to speak SSL to a plain-text server or vice-versa. These exceptions can quickly fill the log with noise when nothing is wrong. *Related issues and pull requests on GitHub:* :issue:`10055`. ---- 3.11.7 (2024-11-21) =================== Bug fixes --------- - Fixed the HTTP client not considering the connector's ``force_close`` value when setting the ``Connection`` header -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10003`. Miscellaneous internal changes ------------------------------ - Improved performance of serializing HTTP headers -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`10014`.
---- 3.11.6 (2024-11-19) =================== Bug fixes --------- - Restored the ``force_close`` method to the ``ResponseHandler`` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9997`. ---- 3.11.5 (2024-11-19) =================== Bug fixes --------- - Fixed the ``ANY`` method not appearing in :meth:`~aiohttp.web.UrlDispatcher.routes` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9899`, :issue:`9987`. ---- 3.11.4 (2024-11-18) =================== Bug fixes --------- - Fixed ``StaticResource`` not allowing the ``OPTIONS`` method after calling ``set_options_route`` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9972`, :issue:`9975`, :issue:`9976`. Miscellaneous internal changes ------------------------------ - Improved performance of creating web responses when there are no cookies -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9895`. ---- 3.11.3 (2024-11-18) =================== Bug fixes --------- - Removed non-existing ``__author__`` from ``dir(aiohttp)`` -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9918`. - Restored the ``FlowControlDataQueue`` class -- by :user:`bdraco`. This class is no longer used internally, and will be permanently removed in the next major version. *Related issues and pull requests on GitHub:* :issue:`9963`. Miscellaneous internal changes ------------------------------ - Improved performance of resolving resources when multiple methods are registered for the same route -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9899`. ---- 3.11.2 (2024-11-14) =================== Bug fixes --------- - Fixed improperly closed WebSocket connections generating an unhandled exception -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9883`. 
---- 3.11.1 (2024-11-14) =================== Bug fixes --------- - Added a backward compatibility layer to :class:`aiohttp.RequestInfo` to allow creating these objects without a ``real_url`` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9873`. ---- 3.11.0 (2024-11-13) =================== Bug fixes --------- - Raise :exc:`aiohttp.ServerFingerprintMismatch` exception on client-side if request through http proxy with mismatching server fingerprint digest: `aiohttp.ClientSession(headers=headers, connector=TCPConnector(ssl=aiohttp.Fingerprint(mismatch_digest), trust_env=True).request(...)` -- by :user:`gangj`. *Related issues and pull requests on GitHub:* :issue:`6652`. - Modified websocket :meth:`aiohttp.ClientWebSocketResponse.receive_str`, :py:meth:`aiohttp.ClientWebSocketResponse.receive_bytes`, :py:meth:`aiohttp.web.WebSocketResponse.receive_str` & :py:meth:`aiohttp.web.WebSocketResponse.receive_bytes` methods to raise new :py:exc:`aiohttp.WSMessageTypeError` exception, instead of generic :py:exc:`TypeError`, when websocket messages of incorrect types are received -- by :user:`ara-25`. *Related issues and pull requests on GitHub:* :issue:`6800`. - Made ``TestClient.app`` a ``Generic`` so type checkers will know the correct type (avoiding unneeded ``client.app is not None`` checks) -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8977`. - Fixed the keep-alive connection pool to be FIFO instead of LIFO -- by :user:`bdraco`. Keep-alive connections are more likely to be reused before they disconnect. *Related issues and pull requests on GitHub:* :issue:`9672`. 
Features -------- - Added ``strategy`` parameter to :meth:`aiohttp.web.StreamResponse.enable_compression` The value of this parameter is passed to the :func:`zlib.compressobj` function, allowing people to use a more sufficient compression algorithm for their data served by :mod:`aiohttp.web` -- by :user:`shootkin` *Related issues and pull requests on GitHub:* :issue:`6257`. - Added ``server_hostname`` parameter to ``ws_connect``. *Related issues and pull requests on GitHub:* :issue:`7941`. - Exported :py:class:`~aiohttp.ClientWSTimeout` to top-level namespace -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8612`. - Added ``secure``/``httponly``/``samesite`` parameters to ``.del_cookie()`` -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8956`. - Updated :py:class:`~aiohttp.ClientSession`'s auth logic to include default auth only if the request URL's origin matches _base_url; otherwise, the auth will not be included -- by :user:`MaximZemskov` *Related issues and pull requests on GitHub:* :issue:`8966`, :issue:`9466`. - Added ``proxy`` and ``proxy_auth`` parameters to :py:class:`~aiohttp.ClientSession` -- by :user:`meshya`. *Related issues and pull requests on GitHub:* :issue:`9207`. - Added ``default_to_multipart`` parameter to ``FormData``. *Related issues and pull requests on GitHub:* :issue:`9335`. - Added :py:meth:`~aiohttp.ClientWebSocketResponse.send_frame` and :py:meth:`~aiohttp.web.WebSocketResponse.send_frame` for WebSockets -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9348`. - Updated :py:class:`~aiohttp.ClientSession` to support paths in ``base_url`` parameter. ``base_url`` paths must end with a ``/`` -- by :user:`Cycloctane`. *Related issues and pull requests on GitHub:* :issue:`9530`. - Improved performance of reading WebSocket messages with a Cython implementation -- by :user:`bdraco`. 
*Related issues and pull requests on GitHub:* :issue:`9543`, :issue:`9554`, :issue:`9556`, :issue:`9558`, :issue:`9636`, :issue:`9649`, :issue:`9781`. - Added ``writer_limit`` to the :py:class:`~aiohttp.web.WebSocketResponse` to be able to adjust the limit before the writer forces the buffer to be drained -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9572`. - Added an :attr:`~aiohttp.abc.AbstractAccessLogger.enabled` property to :class:`aiohttp.abc.AbstractAccessLogger` to dynamically check if logging is enabled -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9822`. Deprecations (removal in next major release) -------------------------------------------- - Deprecate obsolete `timeout: float` and `receive_timeout: Optional[float]` in :py:meth:`~aiohttp.ClientSession.ws_connect`. Change default websocket receive timeout from `None` to `10.0`. *Related issues and pull requests on GitHub:* :issue:`3945`. Removals and backward incompatible breaking changes --------------------------------------------------- - Dropped support for Python 3.8 -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8797`. - Increased minimum yarl version to 1.17.0 -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8909`, :issue:`9079`, :issue:`9305`, :issue:`9574`. - Removed the ``is_ipv6_address`` and ``is_ip4_address`` helpers as they are no longer used -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9344`. - Changed ``ClientRequest.connection_key`` to be a `NamedTuple` to improve client performance -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9365`. - ``FlowControlDataQueue`` has been replaced with the ``WebSocketDataQueue`` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9685`. - Changed ``ClientRequest.request_info`` to be a `NamedTuple` to improve client performance -- by :user:`bdraco`.
*Related issues and pull requests on GitHub:* :issue:`9692`. Packaging updates and notes for downstreams ------------------------------------------- - Switched to using the :mod:`propcache <propcache.api>` package for property caching -- by :user:`bdraco`. The :mod:`propcache <propcache.api>` package is derived from the property caching code in :mod:`yarl` and has been broken out to avoid maintaining it for multiple projects. *Related issues and pull requests on GitHub:* :issue:`9394`. - Separated ``aiohttp.http_websocket`` into multiple files to make it easier to maintain -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9542`, :issue:`9552`. Contributor-facing changes -------------------------- - Changed diagram images generator from ``blockdiag`` to ``GraphViz``. Generating documentation now requires the GraphViz executable to be included in $PATH or sphinx build configuration. *Related issues and pull requests on GitHub:* :issue:`9359`. Miscellaneous internal changes ------------------------------ - Added flake8 settings to avoid some forms of implicit concatenation -- by :user:`booniepepper`. *Related issues and pull requests on GitHub:* :issue:`7731`. - Enabled keep-alive support on proxies (which was originally disabled several years ago) -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8920`. - Changed web entry point to not listen on TCP when only a Unix path is passed -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9033`. - Disabled automatic retries of failed requests in :class:`aiohttp.test_utils.TestClient`'s client session (which could potentially hide errors in tests) -- by :user:`ShubhAgarwal-dev`. *Related issues and pull requests on GitHub:* :issue:`9141`. - Changed web ``keepalive_timeout`` default to around an hour in order to reduce race conditions on reverse proxies -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9285`.
- Reduced memory required for stream objects created during the client request lifecycle -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9407`. - Improved performance of the internal ``DataQueue`` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9659`. - Improved performance of calling ``receive`` for WebSockets for the most common message types -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9679`. - Replace internal helper methods ``method_must_be_empty_body`` and ``status_code_must_be_empty_body`` with simple `set` lookups -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9722`. - Improved performance of :py:class:`aiohttp.BaseConnector` when there is no ``limit_per_host`` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9756`. - Improved performance of sending HTTP requests when there is no body -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9757`. - Improved performance of the ``WebsocketWriter`` when the protocol is not paused -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9796`. - Implemented zero copy writes for ``StreamWriter`` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9839`. ---- 3.10.11 (2024-11-13) ==================== Bug fixes --------- - Authentication provided by a redirect now takes precedence over provided ``auth`` when making requests with the client -- by :user:`PLPeeters`. *Related issues and pull requests on GitHub:* :issue:`9436`. - Fixed :py:meth:`WebSocketResponse.close() <aiohttp.web.WebSocketResponse.close>` to discard non-close messages within its timeout window after sending close -- by :user:`lenard-mosys`. *Related issues and pull requests on GitHub:* :issue:`9506`. - Fixed a deadlock that could occur while attempting to get a new connection slot after a timeout -- by :user:`bdraco`. The connector was not cancellation-safe.
*Related issues and pull requests on GitHub:* :issue:`9670`, :issue:`9671`. - Fixed the WebSocket flow control calculation undercounting with multi-byte data -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9686`. - Fixed incorrect parsing of chunk extensions with the pure Python parser -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9851`. - Fixed system routes polluting the middleware cache -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9852`. Removals and backward incompatible breaking changes --------------------------------------------------- - Improved performance of the connector when a connection can be reused -- by :user:`bdraco`. If ``BaseConnector.connect`` has been subclassed and replaced with custom logic, the ``ceil_timeout`` must be added. *Related issues and pull requests on GitHub:* :issue:`9600`. Miscellaneous internal changes ------------------------------ - Improved performance of the client request lifecycle when there are no cookies -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9470`. - Improved performance of sending client requests when the writer can finish synchronously -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9485`. - Improved performance of serializing HTTP headers -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9603`. - Passing ``enable_cleanup_closed`` to :py:class:`aiohttp.TCPConnector` is now ignored on Python 3.12.7+ and 3.13.1+ since the underlying bug that caused asyncio to leak SSL connections has been fixed upstream -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9726`, :issue:`9736`. ---- 3.10.10 (2024-10-10) ==================== Bug fixes --------- - Fixed error messages from :py:class:`~aiohttp.resolver.AsyncResolver` being swallowed -- by :user:`bdraco`. 
*Related issues and pull requests on GitHub:* :issue:`9451`, :issue:`9455`. Features -------- - Added :exc:`aiohttp.ClientConnectorDNSError` for differentiating DNS resolution errors from other connector errors -- by :user:`mstojcevich`. *Related issues and pull requests on GitHub:* :issue:`8455`. Miscellaneous internal changes ------------------------------ - Simplified DNS resolution throttling code to reduce chance of race conditions -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9454`. ---- 3.10.9 (2024-10-04) =================== Bug fixes --------- - Fixed proxy headers being used in the ``ConnectionKey`` hash when a proxy was not being used -- by :user:`bdraco`. If default headers are used, they are also used for proxy headers. This could have led to creating connections that were not needed when one was already available. *Related issues and pull requests on GitHub:* :issue:`9368`. - Widened the type of the ``trace_request_ctx`` parameter of :meth:`ClientSession.request() <aiohttp.ClientSession.request>` and friends -- by :user:`layday`. *Related issues and pull requests on GitHub:* :issue:`9397`. Removals and backward incompatible breaking changes --------------------------------------------------- - Fixed failure to try next host after single-host connection timeout -- by :user:`brettdh`. The default client :class:`aiohttp.ClientTimeout` params have changed to include a ``sock_connect`` timeout of 30 seconds so that this correct behavior happens by default. *Related issues and pull requests on GitHub:* :issue:`7342`. Miscellaneous internal changes ------------------------------ - Improved performance of resolving hosts with Python 3.12+ -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9342`. - Reduced memory required for timer objects created during the client request lifecycle -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9406`.
---- 3.10.8 (2024-09-28) =================== Bug fixes --------- - Fixed cancellation leaking upwards on timeout -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9326`. ---- 3.10.7 (2024-09-27) =================== Bug fixes --------- - Fixed assembling the :class:`~yarl.URL` for web requests when the host contains a non-default port or IPv6 address -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9309`. Miscellaneous internal changes ------------------------------ - Improved performance of determining if a URL is absolute -- by :user:`bdraco`. The property :attr:`~yarl.URL.absolute` is more performant than the method ``URL.is_absolute()`` and preferred when newer versions of yarl are used. *Related issues and pull requests on GitHub:* :issue:`9171`. - Replaced code that can now be handled by ``yarl`` -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9301`. ---- 3.10.6 (2024-09-24) =================== Bug fixes --------- - Added :exc:`aiohttp.ClientConnectionResetError`. Client code that previously threw :exc:`ConnectionResetError` will now throw this -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9137`. - Fixed an unclosed transport ``ResourceWarning`` on web handlers -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8875`. - Fixed resolve_host() 'Task was destroyed but is pending' errors -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8967`. - Fixed handling of some file-like objects (e.g. ``tarfile.extractfile()``) which raise ``AttributeError`` instead of ``OSError`` when ``fileno`` fails for streaming payload data -- by :user:`ReallyReivax`. *Related issues and pull requests on GitHub:* :issue:`6732`. - Fixed web router not matching pre-encoded URLs (requires yarl 1.9.6+) -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8898`, :issue:`9267`. 
- Fixed an error when trying to add a route for multiple methods with a path containing a regex pattern -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8998`. - Fixed ``Response.text`` when body is a ``Payload`` -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`6485`. - Fixed compressed requests failing when no body was provided -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9108`. - Fixed client incorrectly reusing a connection when the previous message had not been fully sent -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8992`. - Fixed race condition that could cause server to close connection incorrectly at keepalive timeout -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9140`. - Fixed Python parser chunked handling with multiple Transfer-Encoding values -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8823`. - Fixed error handling after 100-continue so server sends 500 response instead of disconnecting -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8876`. - Stopped adding a default Content-Type header when response has no content -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8858`. - Added support for URL credentials with empty (zero-length) username, e.g. ``https://:password@host`` -- by :user:`shuckc` *Related issues and pull requests on GitHub:* :issue:`6494`. - Stopped logging exceptions from ``web.run_app()`` that would be raised regardless -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`6807`. - Implemented binding to IPv6 addresses in the pytest server fixture. *Related issues and pull requests on GitHub:* :issue:`4650`. 
- Fixed the incorrect use of flags for ``getnameinfo()`` in the Resolver -- by :user:`GitNMLee`. Link-Local IPv6 addresses can now be handled by the Resolver correctly. *Related issues and pull requests on GitHub:* :issue:`9032`. - Fixed StreamResponse.prepared to return True after EOF is sent -- by :user:`arthurdarcet`. *Related issues and pull requests on GitHub:* :issue:`5343`. - Changed ``make_mocked_request()`` to use empty payload by default -- by :user:`rahulnht`. *Related issues and pull requests on GitHub:* :issue:`7167`. - Used more precise type for ``ClientResponseError.headers``, fixing some type errors when using them -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8768`. - Changed behavior when returning an invalid response to send a 500 response -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8845`. - Fixed response reading from closed session to throw an error immediately instead of timing out -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8878`. - Fixed ``CancelledError`` from one cleanup context stopping other contexts from completing -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8908`. - Fixed changing scheme/host in ``Response.clone()`` for absolute URLs -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8990`. - Fixed ``Site.name`` when host is an empty string -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8929`. - Updated Python parser to reject messages after a close message, matching C parser behaviour -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9018`. - Fixed creation of ``SSLContext`` inside of :py:class:`aiohttp.TCPConnector` with multiple event loops in different threads -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9029`.
- Fixed (on Python 3.11+) some edge cases where a task cancellation may get incorrectly suppressed -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9030`. - Fixed exception information getting lost on ``HttpProcessingError`` -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9052`. - Fixed ``If-None-Match`` not using weak comparison -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9063`. - Fixed badly encoded charset crashing when getting response text instead of falling back to charset detector. *Related issues and pull requests on GitHub:* :issue:`9160`. - Rejected `\n` in `reason` values to avoid sending broken HTTP messages -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9167`. - Changed :py:meth:`ClientResponse.raise_for_status() <aiohttp.ClientResponse.raise_for_status>` to only release the connection when invoked outside an ``async with`` context -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9239`. Features -------- - Improved type on ``params`` to match the underlying type allowed by ``yarl`` -- by :user:`lpetre`. *Related issues and pull requests on GitHub:* :issue:`8564`. - Declared Python 3.13 supported -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8748`. Removals and backward incompatible breaking changes --------------------------------------------------- - Improved middleware performance -- by :user:`bdraco`. The ``set_current_app`` method was removed from ``UrlMappingMatchInfo`` because it is no longer used, and it was unlikely that an external caller would ever use it. *Related issues and pull requests on GitHub:* :issue:`9200`. - Increased minimum yarl version to 1.12.0 -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9267`.
Improved documentation ---------------------- - Clarified that ``GracefulExit`` needs to be handled in ``AppRunner`` and ``ServerRunner`` when using ``handle_signals=True``. -- by :user:`Daste745` *Related issues and pull requests on GitHub:* :issue:`4414`. - Clarified that auth parameter in ClientSession will persist and be included with any request to any origin, even during redirects to different origins. -- by :user:`MaximZemskov`. *Related issues and pull requests on GitHub:* :issue:`6764`. - Clarified which timeout exceptions happen on which timeouts -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8968`. - Updated ``ClientSession`` parameters to match current code -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8991`. Packaging updates and notes for downstreams ------------------------------------------- - Fixed ``test_client_session_timeout_zero`` to not require internet access -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`9004`. Miscellaneous internal changes ------------------------------ - Improved performance of making requests when there are no auto headers to skip -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8847`. - Exported ``aiohttp.TraceRequestHeadersSentParams`` -- by :user:`Hadock-is-ok`. *Related issues and pull requests on GitHub:* :issue:`8947`. - Avoided tracing overhead in the http writer when there are no active traces -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9031`. - Improved performance of reify Cython implementation -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9054`. - Use :meth:`URL.extend_query() <yarl.URL.extend_query>` to extend query params (requires yarl 1.11.0+) -- by :user:`bdraco`. If yarl is older than 1.11.0, the previous slower hand rolled version will be used. *Related issues and pull requests on GitHub:* :issue:`9068`.
- Improved performance of checking if a host is an IP Address -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9095`. - Significantly improved performance of middlewares -- by :user:`bdraco`. The construction of the middleware wrappers is now cached and is built once per handler instead of on every request. *Related issues and pull requests on GitHub:* :issue:`9158`, :issue:`9170`. - Improved performance of web requests -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9168`, :issue:`9169`, :issue:`9172`, :issue:`9174`, :issue:`9175`, :issue:`9241`. - Improved performance of starting web requests when there is no response prepare hook -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9173`. - Significantly improved performance of expiring cookies -- by :user:`bdraco`. Expiring cookies has been redesigned to use :mod:`heapq` instead of a linear search, to better scale. *Related issues and pull requests on GitHub:* :issue:`9203`. - Significantly sped up filtering cookies -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`9204`. ---- 3.10.5 (2024-08-19) ========================= Bug fixes --------- - Fixed :meth:`aiohttp.ClientResponse.json()` not setting ``status`` when :exc:`aiohttp.ContentTypeError` is raised -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8742`. Miscellaneous internal changes ------------------------------ - Improved performance of the WebSocket reader -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8736`, :issue:`8747`. ---- 3.10.4 (2024-08-17) =================== Bug fixes --------- - Fixed decoding base64 chunk in BodyPartReader -- by :user:`hyzyla`. *Related issues and pull requests on GitHub:* :issue:`3867`. - Fixed a race closing the server-side WebSocket where the close code would not reach the client -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8680`. 
- Fixed unconsumed exceptions raised by the WebSocket heartbeat -- by :user:`bdraco`. If the heartbeat ping raised an exception, it would not be consumed and would be logged as a warning. *Related issues and pull requests on GitHub:* :issue:`8685`. - Fixed an edge case in the Python parser when chunk separators happen to align with network chunks -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8720`. Improved documentation ---------------------- - Added ``aiohttp-apischema`` to supported libraries -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8700`. Miscellaneous internal changes ------------------------------ - Improved performance of starting request handlers with Python 3.12+ -- by :user:`bdraco`. This change is a followup to :issue:`8661` to make the same optimization for Python 3.12+ where the request is connected. *Related issues and pull requests on GitHub:* :issue:`8681`. ---- 3.10.3 (2024-08-10) ======================== Bug fixes --------- - Fixed multipart reading when stream buffer splits the boundary over several read() calls -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8653`. - Fixed :py:class:`aiohttp.TCPConnector` doing blocking I/O in the event loop to create the ``SSLContext`` -- by :user:`bdraco`. The blocking I/O would only happen once per verify mode. However, it could cause the event loop to block for a long time if the ``SSLContext`` creation is slow, which is more likely during startup when the disk cache is not yet present. *Related issues and pull requests on GitHub:* :issue:`8672`. Miscellaneous internal changes ------------------------------ - Improved performance of :py:meth:`~aiohttp.ClientWebSocketResponse.receive` and :py:meth:`~aiohttp.web.WebSocketResponse.receive` when there is no timeout. -- by :user:`bdraco`.
The timeout context manager is now avoided when there is no timeout as it accounted for up to 50% of the time spent in the :py:meth:`~aiohttp.ClientWebSocketResponse.receive` and :py:meth:`~aiohttp.web.WebSocketResponse.receive` methods. *Related issues and pull requests on GitHub:* :issue:`8660`. - Improved performance of starting request handlers with Python 3.12+ -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8661`. - Improved performance of HTTP keep-alive checks -- by :user:`bdraco`. Previously, when processing a request for a keep-alive connection, the keep-alive check would happen every second; the check is now rescheduled if it fires too early instead. *Related issues and pull requests on GitHub:* :issue:`8662`. - Improved performance of generating random WebSocket mask -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8667`. ---- 3.10.2 (2024-08-08) =================== Bug fixes --------- - Fixed server checks for circular symbolic links to be compatible with Python 3.13 -- by :user:`steverep`. *Related issues and pull requests on GitHub:* :issue:`8565`. - Fixed request body not being read when ignoring an Upgrade request -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8597`. - Fixed an edge case where shutdown would wait for timeout when the handler was already completed -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8611`. - Fixed connecting to ``npipe://``, ``tcp://``, and ``unix://`` urls -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8632`. - Fixed WebSocket ping tasks being prematurely garbage collected -- by :user:`bdraco`. There was a small risk that WebSocket ping tasks would be prematurely garbage collected because the event loop only holds a weak reference to the task. The garbage collection risk has been fixed by holding a strong reference to the task. 
Additionally, the task is now scheduled eagerly with Python 3.12+ to increase the chance it can be completed immediately and avoid having to hold any references to the task. *Related issues and pull requests on GitHub:* :issue:`8641`. - Fixed incorrectly following symlinks for compressed file variants -- by :user:`steverep`. *Related issues and pull requests on GitHub:* :issue:`8652`. Removals and backward incompatible breaking changes --------------------------------------------------- - Removed ``Request.wait_for_disconnection()``, which was mistakenly added briefly in 3.10.0 -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8636`. Contributor-facing changes -------------------------- - Fixed monkey patches for ``Path.stat()`` and ``Path.is_dir()`` for Python 3.13 compatibility -- by :user:`steverep`. *Related issues and pull requests on GitHub:* :issue:`8551`. Miscellaneous internal changes ------------------------------ - Improved WebSocket performance when messages are sent or received frequently -- by :user:`bdraco`. The WebSocket heartbeat scheduling algorithm was improved to reduce the ``asyncio`` scheduling overhead by decreasing the number of ``asyncio.TimerHandle`` creations and cancellations. *Related issues and pull requests on GitHub:* :issue:`8608`. - Minor improvements to various type annotations -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8634`. ---- 3.10.1 (2024-08-03) ======================== Bug fixes --------- - Fixed WebSocket server heartbeat timeout logic to terminate :py:meth:`~aiohttp.ClientWebSocketResponse.receive` and return :py:class:`~aiohttp.ServerTimeoutError` -- by :user:`arcivanov`. When a WebSocket pong message was not received, the :py:meth:`~aiohttp.ClientWebSocketResponse.receive` operation did not terminate. 
This change causes ``_pong_not_received`` to feed the ``reader`` an error message, causing pending :py:meth:`~aiohttp.ClientWebSocketResponse.receive` to terminate and return the error message. The error message contains the exception :py:class:`~aiohttp.ServerTimeoutError`. *Related issues and pull requests on GitHub:* :issue:`8540`. - Fixed url dispatcher index not matching when a variable is preceded by a fixed string after a slash -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8566`. Removals and backward incompatible breaking changes --------------------------------------------------- - Creating :py:class:`aiohttp.TCPConnector`, :py:class:`aiohttp.ClientSession`, :py:class:`~aiohttp.resolver.ThreadedResolver`, :py:class:`aiohttp.web.Server`, or :py:class:`aiohttp.CookieJar` instances without a running event loop now raises a :exc:`RuntimeError` -- by :user:`asvetlov`. Creating these objects without a running event loop was deprecated in :issue:`3372` which was released in version 3.5.0. This change first appeared in version 3.10.0 as :issue:`6378`. *Related issues and pull requests on GitHub:* :issue:`8555`, :issue:`8583`. ---- 3.10.0 (2024-07-30) ======================== Bug fixes --------- - Fixed server response headers for ``Content-Type`` and ``Content-Encoding`` for static compressed files -- by :user:`steverep`. Server will now respond with a ``Content-Type`` appropriate for the compressed file (e.g. ``"application/gzip"``), and omit the ``Content-Encoding`` header. Users should expect that most clients will no longer decompress such responses by default. *Related issues and pull requests on GitHub:* :issue:`4462`. - Fixed duplicate cookie expiration calls in the CookieJar implementation *Related issues and pull requests on GitHub:* :issue:`7784`. - Adjusted ``FileResponse`` to check file existence and access when preparing the response -- by :user:`steverep`.
The :py:class:`~aiohttp.web.FileResponse` class was modified to respond with 403 Forbidden or 404 Not Found as appropriate. Previously, it would cause a server error if the path did not exist or could not be accessed. Checks for existence, non-regular files, and permissions were expected to be done in the route handler. For static routes, this now permits a compressed file to exist without its uncompressed variant and still be served. In addition, this changes the response status for files without read permission to 403, and for non-regular files from 404 to 403 for consistency. *Related issues and pull requests on GitHub:* :issue:`8182`. - Fixed ``AsyncResolver`` to match ``ThreadedResolver`` behavior -- by :user:`bdraco`. On systems with IPv6 support, the :py:class:`~aiohttp.resolver.AsyncResolver` would not fall back to providing A records when AAAA records were not available. Additionally, unlike the :py:class:`~aiohttp.resolver.ThreadedResolver`, the :py:class:`~aiohttp.resolver.AsyncResolver` did not handle link-local addresses correctly. This change makes the behavior consistent with the :py:class:`~aiohttp.resolver.ThreadedResolver`. *Related issues and pull requests on GitHub:* :issue:`8270`. - Fixed ``ws_connect`` not respecting ``receive_timeout`` on WS(S) connection. -- by :user:`arcivanov`. *Related issues and pull requests on GitHub:* :issue:`8444`. - Removed blocking I/O in the event loop for static resources and refactored exception handling -- by :user:`steverep`. File system calls when handling requests for static routes were moved to a separate thread to potentially improve performance. Exception handling was tightened in order to only return 403 Forbidden or 404 Not Found responses for expected scenarios; 500 Internal Server Error would be returned for any unknown errors. *Related issues and pull requests on GitHub:* :issue:`8507`.
Features -------- - Added a Request.wait_for_disconnection() method, as a means of allowing request handlers to be notified of premature client disconnections. *Related issues and pull requests on GitHub:* :issue:`2492`. - Added 5 new exceptions: :py:exc:`~aiohttp.InvalidUrlClientError`, :py:exc:`~aiohttp.RedirectClientError`, :py:exc:`~aiohttp.NonHttpUrlClientError`, :py:exc:`~aiohttp.InvalidUrlRedirectClientError`, :py:exc:`~aiohttp.NonHttpUrlRedirectClientError`. :py:exc:`~aiohttp.InvalidUrlRedirectClientError`, :py:exc:`~aiohttp.NonHttpUrlRedirectClientError` are raised instead of :py:exc:`ValueError` or :py:exc:`~aiohttp.InvalidURL` when the redirect URL is invalid. Classes :py:exc:`~aiohttp.InvalidUrlClientError`, :py:exc:`~aiohttp.RedirectClientError`, :py:exc:`~aiohttp.NonHttpUrlClientError` are base for them. The :py:exc:`~aiohttp.InvalidURL` now exposes a ``description`` property with the text explanation of the error details. -- by :user:`setla`, :user:`AraHaan`, and :user:`bdraco` *Related issues and pull requests on GitHub:* :issue:`2507`, :issue:`3315`, :issue:`6722`, :issue:`8481`, :issue:`8482`. - Added a feature to retry closed connections automatically for idempotent methods. -- by :user:`Dreamsorcerer` *Related issues and pull requests on GitHub:* :issue:`7297`. - Implemented filter_cookies() with domain-matching and path-matching on the keys, instead of testing every single cookie. This may break existing cookies that have been saved with `CookieJar.save()`. Cookies can be migrated with this script:: import pickle with file_path.open("rb") as f: cookies = pickle.load(f) morsels = [(name, m) for c in cookies.values() for name, m in c.items()] cookies.clear() for name, m in morsels: cookies[(m["domain"], m["path"].rstrip("/"))][name] = m with file_path.open("wb") as f: pickle.dump(cookies, f, pickle.HIGHEST_PROTOCOL) *Related issues and pull requests on GitHub:* :issue:`7583`, :issue:`8535`. 
- Separated connection and socket timeout errors, from ServerTimeoutError. *Related issues and pull requests on GitHub:* :issue:`7801`. - Implemented happy eyeballs *Related issues and pull requests on GitHub:* :issue:`7954`. - Added server capability to check for static files with Brotli compression via a ``.br`` extension -- by :user:`steverep`. *Related issues and pull requests on GitHub:* :issue:`8062`. Removals and backward incompatible breaking changes --------------------------------------------------- - The shutdown logic in 3.9 waited on all tasks, which caused issues with some libraries. In 3.10 we've changed this logic to only wait on request handlers. This means that it's important for developers to correctly handle the lifecycle of background tasks using a library such as ``aiojobs``. If an application is using ``handler_cancellation=True`` then it is also a good idea to ensure that any :func:`asyncio.shield` calls are replaced with :func:`aiojobs.aiohttp.shield`. Please read the updated documentation on these points: \ https://docs.aiohttp.org/en/stable/web_advanced.html#graceful-shutdown \ https://docs.aiohttp.org/en/stable/web_advanced.html#web-handler-cancellation -- by :user:`Dreamsorcerer` *Related issues and pull requests on GitHub:* :issue:`8495`. Improved documentation ---------------------- - Added documentation for ``aiohttp.web.FileResponse``. *Related issues and pull requests on GitHub:* :issue:`3958`. - Improved the docs for the `ssl` params. *Related issues and pull requests on GitHub:* :issue:`8403`. Contributor-facing changes -------------------------- - Enabled HTTP parser tests originally intended for 3.9.2 release -- by :user:`pajod`. *Related issues and pull requests on GitHub:* :issue:`8088`. Miscellaneous internal changes ------------------------------ - Improved URL handler resolution time by indexing resources in the UrlDispatcher. For applications with a large number of handlers, this should increase performance significantly. 
-- by :user:`bdraco` *Related issues and pull requests on GitHub:* :issue:`7829`. - Added `nacl_middleware `_ to the list of middlewares in the third party section of the documentation. *Related issues and pull requests on GitHub:* :issue:`8346`. - Minor improvements to static typing -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8364`. - Added 3.11-specific overloads to ``ClientSession`` -- by :user:`max-muoto`. *Related issues and pull requests on GitHub:* :issue:`8463`. - Simplified path checks for ``UrlDispatcher.add_static()`` method -- by :user:`steverep`. *Related issues and pull requests on GitHub:* :issue:`8491`. - Avoided creating a future on every websocket receive -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8498`. - Updated identity checks for all ``WSMsgType`` type compares -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8501`. - When using Python 3.12 or later, the writer is no longer scheduled on the event loop if it can finish synchronously. Avoiding event loop scheduling reduces latency and improves performance. -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8510`. - Restored :py:class:`~aiohttp.resolver.AsyncResolver` to be the default resolver. -- by :user:`bdraco`. :py:class:`~aiohttp.resolver.AsyncResolver` was disabled by default because of IPv6 compatibility issues. These issues have been resolved and :py:class:`~aiohttp.resolver.AsyncResolver` is again now the default resolver. *Related issues and pull requests on GitHub:* :issue:`8522`. ---- 3.9.5 (2024-04-16) ================== Bug fixes --------- - Fixed "Unclosed client session" when initialization of :py:class:`~aiohttp.ClientSession` fails -- by :user:`NewGlad`. *Related issues and pull requests on GitHub:* :issue:`8253`. 
- Fixed regression (from :pr:`8280`) with adding ``Content-Disposition`` to the ``form-data`` part after appending to writer -- by :user:`Dreamsorcerer`/:user:`Olegt0rr`. *Related issues and pull requests on GitHub:* :issue:`8332`. - Added default ``Content-Disposition`` in ``multipart/form-data`` responses to avoid broken form-data responses -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8335`. ---- 3.9.4 (2024-04-11) ================== Bug fixes --------- - The asynchronous internals now set the underlying causes when assigning exceptions to the future objects -- by :user:`webknjaz`. *Related issues and pull requests on GitHub:* :issue:`8089`. - Treated values of ``Accept-Encoding`` header as case-insensitive when checking for gzip files -- by :user:`steverep`. *Related issues and pull requests on GitHub:* :issue:`8104`. - Improved the DNS resolution performance on cache hit -- by :user:`bdraco`. This is achieved by avoiding an :mod:`asyncio` task creation in this case. *Related issues and pull requests on GitHub:* :issue:`8163`. - Changed the type annotations to allow ``dict`` on :meth:`aiohttp.MultipartWriter.append`, :meth:`aiohttp.MultipartWriter.append_json` and :meth:`aiohttp.MultipartWriter.append_form` -- by :user:`cakemanny` *Related issues and pull requests on GitHub:* :issue:`7741`. - Ensure websocket transport is closed when client does not close it -- by :user:`bdraco`. The transport could remain open if the client did not close it. This change ensures the transport is closed when the client does not close it. *Related issues and pull requests on GitHub:* :issue:`8200`. - Leave websocket transport open if receive times out or is cancelled -- by :user:`bdraco`. This restores the behavior prior to the change in #7978. *Related issues and pull requests on GitHub:* :issue:`8251`. - Fixed content not being read when an upgrade request was not supported with the pure Python implementation. -- by :user:`bdraco`. 
*Related issues and pull requests on GitHub:* :issue:`8252`. - Fixed a race condition with incoming connections during server shutdown -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8271`. - Fixed ``multipart/form-data`` compliance with :rfc:`7578` -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8280`. - Fixed blocking I/O in the event loop while processing files in a POST request -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8283`. - Escaped filenames in static view -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8317`. - Fixed the pure python parser to mark a connection as closing when a response has no length -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8320`. Features -------- - Upgraded *llhttp* to 9.2.1, and started rejecting obsolete line folding in Python parser to match -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8146`, :issue:`8292`. Deprecations (removal in next major release) -------------------------------------------- - Deprecated ``content_transfer_encoding`` parameter in :py:meth:`FormData.add_field() ` -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8280`. Improved documentation ---------------------- - Added a note about canceling tasks to avoid delaying server shutdown -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8267`. Contributor-facing changes -------------------------- - The pull request template is now asking the contributors to answer a question about the long-term maintenance challenges they envision as a result of merging their patches -- by :user:`webknjaz`. *Related issues and pull requests on GitHub:* :issue:`8099`. - Updated CI and documentation to use NPM clean install and upgrade node to version 18 -- by :user:`steverep`. 
*Related issues and pull requests on GitHub:* :issue:`8116`. - A pytest fixture ``hello_txt`` was introduced to aid static file serving tests in :file:`test_web_sendfile_functional.py`. It dynamically provisions ``hello.txt`` file variants shared across the tests in the module. -- by :user:`steverep` *Related issues and pull requests on GitHub:* :issue:`8136`. Packaging updates and notes for downstreams ------------------------------------------- - Added an ``internal`` pytest marker for tests which should be skipped by packagers (use ``-m 'not internal'`` to disable them) -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8299`. ---- 3.9.3 (2024-01-29) ================== Bug fixes --------- - Fixed backwards compatibility breakage (in 3.9.2) of ``ssl`` parameter when set outside of ``ClientSession`` (e.g. directly in ``TCPConnector``) -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`8097`, :issue:`8098`. Miscellaneous internal changes ------------------------------ - Improved test suite handling of paths and temp files to consistently use pathlib and pytest fixtures. *Related issues and pull requests on GitHub:* :issue:`3957`. ---- 3.9.2 (2024-01-28) ================== Bug fixes --------- - Fixed server-side websocket connection leak. *Related issues and pull requests on GitHub:* :issue:`7978`. - Fixed ``web.FileResponse`` doing blocking I/O in the event loop. *Related issues and pull requests on GitHub:* :issue:`8012`. - Fixed double compress when compression enabled and compressed file exists in server file responses. *Related issues and pull requests on GitHub:* :issue:`8014`. - Added runtime type check for ``ClientSession`` ``timeout`` parameter. *Related issues and pull requests on GitHub:* :issue:`8021`. - Fixed an unhandled exception in the Python HTTP parser on header lines starting with a colon -- by :user:`pajod`. 
Invalid request lines with anything but a dot between the HTTP major and minor version are now rejected. Invalid header field names containing question mark or slash are now rejected. Such requests are incompatible with :rfc:`9110#section-5.6.2` and are not known to be of any legitimate use. *Related issues and pull requests on GitHub:* :issue:`8074`. - Improved validation of paths for static resources requests to the server -- by :user:`bdraco`. *Related issues and pull requests on GitHub:* :issue:`8079`. Features -------- - Added support for passing :py:data:`True` to ``ssl`` parameter in ``ClientSession`` while deprecating :py:data:`None` -- by :user:`xiangyan99`. *Related issues and pull requests on GitHub:* :issue:`7698`. Breaking changes ---------------- - Fixed an unhandled exception in the Python HTTP parser on header lines starting with a colon -- by :user:`pajod`. Invalid request lines with anything but a dot between the HTTP major and minor version are now rejected. Invalid header field names containing question mark or slash are now rejected. Such requests are incompatible with :rfc:`9110#section-5.6.2` and are not known to be of any legitimate use. *Related issues and pull requests on GitHub:* :issue:`8074`. Improved documentation ---------------------- - Fixed examples of ``fallback_charset_resolver`` function in the :doc:`client_advanced` document. -- by :user:`henry0312`. *Related issues and pull requests on GitHub:* :issue:`7995`. - The Sphinx setup was updated to avoid showing the empty changelog draft section in the tagged release documentation builds on Read The Docs -- by :user:`webknjaz`. *Related issues and pull requests on GitHub:* :issue:`8067`. Packaging updates and notes for downstreams ------------------------------------------- - The changelog categorization was made clearer. The contributors can now mark their fragment files more accurately -- by :user:`webknjaz`. 
The new category tags are: * ``bugfix`` * ``feature`` * ``deprecation`` * ``breaking`` (previously, ``removal``) * ``doc`` * ``packaging`` * ``contrib`` * ``misc`` *Related issues and pull requests on GitHub:* :issue:`8066`. Contributor-facing changes -------------------------- - Updated :ref:`contributing/Tests coverage ` section to show how we use ``codecov`` -- by :user:`Dreamsorcerer`. *Related issues and pull requests on GitHub:* :issue:`7916`. - The changelog categorization was made clearer. The contributors can now mark their fragment files more accurately -- by :user:`webknjaz`. The new category tags are: * ``bugfix`` * ``feature`` * ``deprecation`` * ``breaking`` (previously, ``removal``) * ``doc`` * ``packaging`` * ``contrib`` * ``misc`` *Related issues and pull requests on GitHub:* :issue:`8066`. Miscellaneous internal changes ------------------------------ - Replaced all ``tmpdir`` fixtures with ``tmp_path`` in test suite. *Related issues and pull requests on GitHub:* :issue:`3551`. ---- 3.9.1 (2023-11-26) ================== Bugfixes -------- - Fixed importing aiohttp under PyPy on Windows. `#7848 `_ - Fixed async concurrency safety in websocket compressor. `#7865 `_ - Fixed ``ClientResponse.close()`` releasing the connection instead of closing. `#7869 `_ - Fixed a regression where connection may get closed during upgrade. -- by :user:`Dreamsorcerer` `#7879 `_ - Fixed messages being reported as upgraded without an Upgrade header in Python parser. -- by :user:`Dreamsorcerer` `#7895 `_ ---- 3.9.0 (2023-11-18) ================== Features -------- - Introduced ``AppKey`` for static typing support of ``Application`` storage. See https://docs.aiohttp.org/en/stable/web_advanced.html#application-s-config `#5864 `_ - Added a graceful shutdown period which allows pending tasks to complete before the application's cleanup is called. The period can be adjusted with the ``shutdown_timeout`` parameter. -- by :user:`Dreamsorcerer`. 
See https://docs.aiohttp.org/en/latest/web_advanced.html#graceful-shutdown `#7188 `_ - Added `handler_cancellation `_ parameter to cancel web handler on client disconnection. -- by :user:`mosquito` This (optionally) reintroduces a feature removed in a previous release. Recommended for those looking for an extra level of protection against denial-of-service attacks. `#7056 `_ - Added support for setting response header parameters ``max_line_size`` and ``max_field_size``. `#2304 `_ - Added ``auto_decompress`` parameter to ``ClientSession.request`` to override ``ClientSession._auto_decompress``. -- by :user:`Daste745` `#3751 `_ - Changed ``raise_for_status`` to allow a coroutine. `#3892 `_ - Added client brotli compression support (optional with runtime check). `#5219 `_ - Added ``client_max_size`` to ``BaseRequest.clone()`` to allow overriding the request body size. -- :user:`anesabml`. `#5704 `_ - Added a middleware type alias ``aiohttp.typedefs.Middleware``. `#5898 `_ - Exported ``HTTPMove`` which can be used to catch any redirection request that has a location -- :user:`dreamsorcerer`. `#6594 `_ - Changed the ``path`` parameter in ``web.run_app()`` to accept a ``pathlib.Path`` object. `#6839 `_ - Performance: Skipped filtering ``CookieJar`` when the jar is empty or all cookies have expired. `#7819 `_ - Performance: Only check origin if insecure scheme and there are origins to treat as secure, in ``CookieJar.filter_cookies()``. `#7821 `_ - Performance: Used timestamp instead of ``datetime`` to achieve faster cookie expiration in ``CookieJar``. `#7824 `_ - Added support for passing a custom server name parameter to HTTPS connection. `#7114 `_ - Added support for using Basic Auth credentials from :file:`.netrc` file when making HTTP requests with the :py:class:`~aiohttp.ClientSession` ``trust_env`` argument is set to ``True``. -- by :user:`yuvipanda`. `#7131 `_ - Turned access log into no-op when the logger is disabled. 
`#7240 `_ - Added typing information to ``RawResponseMessage``. -- by :user:`Gobot1234` `#7365 `_ - Removed ``async-timeout`` for Python 3.11+ (replaced with ``asyncio.timeout()`` on newer releases). `#7502 `_ - Added support for ``brotlicffi`` as an alternative to ``brotli`` (fixing Brotli support on PyPy). `#7611 `_ - Added ``WebSocketResponse.get_extra_info()`` to access a protocol transport's extra info. `#7078 `_ - Allow ``link`` argument to be set to None/empty in HTTP 451 exception. `#7689 `_ Bugfixes -------- - Implemented stripping the trailing dots from fully-qualified domain names in ``Host`` headers and TLS context when acting as an HTTP client. This allows the client to connect to URLs with FQDN host name like ``https://example.com./``. -- by :user:`martin-sucha`. `#3636 `_ - Fixed client timeout not working when incoming data is always available without waiting. -- by :user:`Dreamsorcerer`. `#5854 `_ - Fixed ``readuntil`` to work with a delimiter of more than one character. `#6701 `_ - Added ``__repr__`` to ``EmptyStreamReader`` to avoid ``AttributeError``. `#6916 `_ - Fixed bug when using ``TCPConnector`` with ``ttl_dns_cache=0``. `#7014 `_ - Fixed response returned from expect handler being thrown away. -- by :user:`Dreamsorcerer` `#7025 `_ - Avoided raising ``UnicodeDecodeError`` in multipart and in HTTP headers parsing. `#7044 `_ - Changed ``sock_read`` timeout to start after writing has finished, avoiding read timeouts caused by an unfinished write. -- by :user:`dtrifiro` `#7149 `_ - Fixed missing query in tracing method URLs when using ``yarl`` 1.9+. `#7259 `_ - Changed max 32-bit timestamp to an aware datetime object, for consistency with the non-32-bit one, and to avoid a ``DeprecationWarning`` on Python 3.12. `#7302 `_ - Fixed ``EmptyStreamReader.iter_chunks()`` never ending. -- by :user:`mind1m` `#7616 `_ - Fixed a rare ``RuntimeError: await wasn't used with future`` exception. 
-- by :user:`stalkerg` `#7785 `_ - Fixed issue with insufficient HTTP method and version validation. `#7700 `_ - Added check to validate that absolute URIs have schemes. `#7712 `_ - Fixed unhandled exception when Python HTTP parser encounters unpaired Unicode surrogates. `#7715 `_ - Updated parser to disallow invalid characters in header field names and stop accepting LF as a request line separator. `#7719 `_ - Fixed Python HTTP parser not treating 204/304/1xx as an empty body. `#7755 `_ - Ensure empty body response for 1xx/204/304 per RFC 9112 sec 6.3. `#7756 `_ - Fixed an issue when a client request is closed before completing a chunked payload. -- by :user:`Dreamsorcerer` `#7764 `_ - Edge Case Handling for ResponseParser for missing reason value. `#7776 `_ - Fixed ``ClientWebSocketResponse.close_code`` being erroneously set to ``None`` when there are concurrent async tasks receiving data and closing the connection. `#7306 `_ - Added HTTP method validation. `#6533 `_ - Fixed arbitrary sequence types being allowed to inject values via version parameter. -- by :user:`Dreamsorcerer` `#7835 `_ - Performance: Fixed increase in latency with small messages from websocket compression changes. `#7797 `_ Improved Documentation ---------------------- - Fixed the `ClientResponse.release`'s type in the doc. Changed from `comethod` to `method`. `#5836 `_ - Added information on behavior of base_url parameter in `ClientSession`. `#6647 `_ - Fixed `ClientResponseError` docs. `#6700 `_ - Updated Redis code examples to follow the latest API. `#6907 `_ - Added a note about possibly needing to update headers when using ``on_response_prepare``. -- by :user:`Dreamsorcerer` `#7283 `_ - Completed ``trust_env`` parameter description to honor ``wss_proxy``, ``ws_proxy`` or ``no_proxy`` env. `#7325 `_ - Expanded SSL documentation with more examples (e.g. how to use certifi). -- by :user:`Dreamsorcerer` `#7334 `_ - Fix, update, and improve client exceptions documentation. 
`#7733 `_ Deprecations and Removals ------------------------- - Added ``shutdown_timeout`` parameter to ``BaseRunner``, while deprecating ``shutdown_timeout`` parameter from ``BaseSite``. -- by :user:`Dreamsorcerer` `#7718 `_ - Dropped Python 3.6 support. `#6378 `_ - Dropped Python 3.7 support. -- by :user:`Dreamsorcerer` `#7336 `_ - Removed support for abandoned ``tokio`` event loop. -- by :user:`Dreamsorcerer` `#7281 `_ Misc ---- - Made ``print`` argument in ``run_app()`` optional. `#3690 `_ - Improved performance of ``ceil_timeout`` in some cases. `#6316 `_ - Changed importing Gunicorn to happen on-demand, decreasing import time by ~53%. -- :user:`Dreamsorcerer` `#6591 `_ - Improved import time by replacing ``http.server`` with ``http.HTTPStatus``. `#6903 `_ - Fixed annotation of ``ssl`` parameter to disallow ``True``. -- by :user:`Dreamsorcerer`. `#7335 `_ ---- 3.8.6 (2023-10-07) ================== Security bugfixes ----------------- - Upgraded the vendored copy of llhttp_ to v9.1.3 -- by :user:`Dreamsorcerer` Thanks to :user:`kenballus` for reporting this, see https://github.com/aio-libs/aiohttp/security/advisories/GHSA-pjjw-qhg8-p2p9. .. _llhttp: https://llhttp.org `#7647 `_ - Updated Python parser to comply with RFCs 9110/9112 -- by :user:`Dreamsorcerer` Thanks to :user:`kenballus` for reporting this, see https://github.com/aio-libs/aiohttp/security/advisories/GHSA-gfw2-4jvh-wgfg. `#7663 `_ Deprecation ----------- - Added ``fallback_charset_resolver`` parameter in ``ClientSession`` to allow a user-supplied character set detection function. Character set detection will no longer be included in 3.9 as a default. If this feature is needed, please use `fallback_charset_resolver `_. `#7561 `_ Features -------- - Enabled lenient response parsing for more flexible parsing in the client (this should resolve some regressions when dealing with badly formatted HTTP responses). 
-- by :user:`Dreamsorcerer` `#7490 `_ Bugfixes -------- - Fixed ``PermissionError`` when ``.netrc`` is unreadable due to permissions. `#7237 `_ - Fixed output of parsing errors pointing to a ``\n``. -- by :user:`Dreamsorcerer` `#7468 `_ - Fixed ``GunicornWebWorker`` max_requests_jitter not working. `#7518 `_ - Fixed sorting in ``filter_cookies`` to use cookie with longest path. -- by :user:`marq24`. `#7577 `_ - Fixed display of ``BadStatusLine`` messages from llhttp_. -- by :user:`Dreamsorcerer` `#7651 `_ ---- 3.8.5 (2023-07-19) ================== Security bugfixes ----------------- - Upgraded the vendored copy of llhttp_ to v8.1.1 -- by :user:`webknjaz` and :user:`Dreamsorcerer`. Thanks to :user:`sethmlarson` for reporting this and providing us with comprehensive reproducer, workarounds and fixing details! For more information, see https://github.com/aio-libs/aiohttp/security/advisories/GHSA-45c4-8wx5-qw6w. .. _llhttp: https://llhttp.org `#7346 `_ Features -------- - Added information to C parser exceptions to show which character caused the error. -- by :user:`Dreamsorcerer` `#7366 `_ Bugfixes -------- - Fixed a transport is :data:`None` error -- by :user:`Dreamsorcerer`. `#3355 `_ ---- 3.8.4 (2023-02-12) ================== Bugfixes -------- - Fixed incorrectly overwriting cookies with the same name and domain, but different path. `#6638 `_ - Fixed ``ConnectionResetError`` not being raised after client disconnection in SSL environments. `#7180 `_ ---- 3.8.3 (2022-09-21) ================== .. attention:: This is the last :doc:`aiohttp ` release tested under Python 3.6. The 3.9 stream is dropping it from the CI and the distribution package metadata. Bugfixes -------- - Increased the upper boundary of the :doc:`multidict:index` dependency to allow for the version 6 -- by :user:`hugovk`. It used to be limited below version 7 in :doc:`aiohttp ` v3.8.1 but was lowered in v3.8.2 via :pr:`6550` and never brought back, causing problems with dependency pins when upgrading. 
:doc:`aiohttp ` v3.8.3 fixes that by recovering the original boundary of ``< 7``. `#6950 `_ ---- 3.8.2 (2022-09-20, subsequently yanked on 2022-09-21) ===================================================== Bugfixes -------- - Support registering OPTIONS HTTP method handlers via RouteTableDef. `#4663 `_ - Started supporting ``authority-form`` and ``absolute-form`` URLs on the server-side. `#6227 `_ - Fix Python 3.11 alpha incompatibilities by using Cython 0.29.25 `#6396 `_ - Remove a deprecated usage of pytest.warns(None) `#6663 `_ - Fix regression where ``asyncio.CancelledError`` occurs on client disconnection. `#6719 `_ - Export :py:class:`~aiohttp.web.PrefixedSubAppResource` under :py:mod:`aiohttp.web` -- by :user:`Dreamsorcerer`. This fixes a regression introduced by :pr:`3469`. `#6889 `_ - Dropped the :class:`object` type possibility from the :py:attr:`aiohttp.ClientSession.timeout` property return type declaration. `#6917 `_, `#6923 `_ Improved Documentation ---------------------- - Added clarification on configuring the app object with settings such as a db connection. `#4137 `_ - Edited the web.run_app declaration. `#6401 `_ - Dropped the :class:`object` type possibility from the :py:attr:`aiohttp.ClientSession.timeout` property return type declaration. `#6917 `_, `#6923 `_ Deprecations and Removals ------------------------- - Drop Python 3.5 support, aiohttp works on 3.6+ now. `#4046 `_ Misc ---- - `#6369 `_, `#6399 `_, `#6550 `_, `#6708 `_, `#6757 `_, `#6857 `_, `#6872 `_ ---- 3.8.1 (2021-11-14) ================== Bugfixes -------- - Fix the error in handling the return value of `getaddrinfo`. `getaddrinfo` will return an `(int, bytes)` tuple, if CPython could not handle the address family. It will cause an index out of range error in aiohttp. For example, if user compile CPython with `--disable-ipv6` option, but his system enable the ipv6. `#5901 `_ - Do not install "examples" as a top-level package. `#6189 `_ - Restored ability to connect IPv6-only host. 
`#6195 `_ - Remove ``Signal`` from ``__all__``, replace ``aiohttp.Signal`` with ``aiosignal.Signal`` in docs `#6201 `_ - Made chunked encoding HTTP header check stricter. `#6305 `_ Improved Documentation ---------------------- - Updated the quick start demo code. `#6240 `_ - Added an explanation of how tiny timeouts affect performance to the client reference document. `#6274 `_ - Add flake8-docstrings to flake8 configuration, enable subset of checks. `#6276 `_ - Added information on running complex applications with additional tasks/processes -- :user:`Dreamsorcerer`. `#6278 `_ Misc ---- - `#6205 `_ ---- 3.8.0 (2021-10-31) ================== Features -------- - Added a ``GunicornWebWorker`` feature for extending the aiohttp server configuration by allowing the 'wsgi' coroutine to return ``web.AppRunner`` object. `#2988 `_ - Switch from ``http-parser`` to ``llhttp`` `#3561 `_ - Use Brotli instead of brotlipy `#3803 `_ - Disable implicit switch-back to pure python mode. The build fails loudly if aiohttp cannot be compiled with C Accelerators. Use AIOHTTP_NO_EXTENSIONS=1 to explicitly disable C Extensions compilation and switch to Pure-Python mode. Note that Pure-Python mode is significantly slower than compiled one. `#3828 `_ - Make access log use local time with timezone `#3853 `_ - Implemented ``readuntil`` in ``StreamResponse`` `#4054 `_ - FileResponse now supports ETag. `#4594 `_ - Add a request handler type alias ``aiohttp.typedefs.Handler``. `#4686 `_ - ``AioHTTPTestCase`` is more async friendly now. For people who use unittest and are used to use :py:exc:`~unittest.TestCase` it will be easier to write new test cases like the sync version of the :py:exc:`~unittest.TestCase` class, without using the decorator `@unittest_run_loop`, just `async def test_*`. The only difference is that for the people using python3.7 and below a new dependency is needed, it is ``asynctestcase``. `#4700 `_ - Add validation of HTTP header keys and values to prevent header injection. 
`#4818 `_ - Add predicate to ``AbstractCookieJar.clear``. Add ``AbstractCookieJar.clear_domain`` to clean all domain and subdomains cookies only. `#4942 `_ - Add keepalive_timeout parameter to web.run_app. `#5094 `_ - Tracing for client sent headers `#5105 `_ - Make type hints for http parser stricter `#5267 `_ - Add final declarations for constants. `#5275 `_ - Switch to external frozenlist and aiosignal libraries. `#5293 `_ - Don't send secure cookies by insecure transports. By default, the transport is secure if https or wss scheme is used. Use `CookieJar(treat_as_secure_origin="http://127.0.0.1")` to override the default security checker. `#5571 `_ - Always create a new event loop in ``aiohttp.web.run_app()``. This adds better compatibility with ``asyncio.run()`` or if trying to run multiple apps in sequence. `#5572 `_ - Add ``aiohttp.pytest_plugin.AiohttpClient`` for static typing of pytest plugin. `#5585 `_ - Added a ``socket_factory`` argument to ``BaseTestServer``. `#5844 `_ - Add compression strategy parameter to enable_compression method. `#5909 `_ - Added support for Python 3.10 to Github Actions CI/CD workflows and fix the related deprecation warnings -- :user:`Hanaasagi`. `#5927 `_ - Switched ``chardet`` to ``charset-normalizer`` for guessing the HTTP payload body encoding -- :user:`Ousret`. `#5930 `_ - Added optional auto_decompress argument for HttpRequestParser `#5957 `_ - Added support for HTTPS proxies to the extent CPython's :py:mod:`asyncio` supports it -- by :user:`bmbouter`, :user:`jborean93` and :user:`webknjaz`. `#5992 `_ - Added ``base_url`` parameter to the initializer of :class:`~aiohttp.ClientSession`. `#6013 `_ - Add Trove classifier and create binary wheels for 3.10. -- :user:`hugovk`. `#6079 `_ - Started shipping platform-specific wheels with the ``musl`` tag targeting typical Alpine Linux runtimes — :user:`asvetlov`. `#6139 `_ - Started shipping platform-specific arm64 wheels for Apple Silicon — :user:`asvetlov`. 
`#6139 `_ Bugfixes -------- - Modify _drain_helper() to handle concurrent `await resp.write(...)` or `ws.send_json(...)` calls without race-condition. `#2934 `_ - Started using `MultiLoopChildWatcher` when it's available under POSIX while setting up the test I/O loop. `#3450 `_ - Only encode content-disposition filename parameter using percent-encoding. Other parameters are encoded to quoted-string or RFC2231 extended parameter value. `#4012 `_ - Fixed HTTP client requests to honor ``no_proxy`` environment variables. `#4431 `_ - Fix supporting WebSockets proxies configured via environment variables. `#4648 `_ - Change return type on URLDispatcher to UrlMappingMatchInfo to improve type annotations. `#4748 `_ - Ensure a cleanup context is cleaned up even when an exception occurs during startup. `#4799 `_ - Added a new exception type for Unix socket client errors which provides a more useful error message. `#4984 `_ - Remove Transfer-Encoding and Content-Type headers for 204 in StreamResponse `#5106 `_ - Only depend on typing_extensions for Python <3.8 `#5107 `_ - Add ABNORMAL_CLOSURE and BAD_GATEWAY to WSCloseCode `#5192 `_ - Fix cookies disappearing from HTTPExceptions. `#5233 `_ - StaticResource prefixes no longer match URLs with a non-folder prefix. For example ``routes.static('/foo', '/foo')`` no longer matches the URL ``/foobar``. Previously, this would attempt to load the file ``/foo/ar``. `#5250 `_ - Acquire the connection before running traces to prevent race condition. `#5259 `_ - Add missing slots to ``_RequestContextManager`` and ``_WSRequestContextManager`` `#5329 `_ - Ensure sending a zero byte file does not throw an exception (round 2) `#5380 `_ - Set "text/plain" when data is an empty string in client requests. `#5392 `_ - Stop automatically releasing the ``ClientResponse`` object on calls to the ``ok`` property for the failed requests. `#5403 `_ - Include query parameters from `params` keyword argument in tracing `URL`. 
`#5432 `_ - Fix annotations `#5466 `_ - Fixed the multipart POST requests processing to always release file descriptors for the ``tempfile.TemporaryFile``-created ``_io.BufferedRandom`` instances of files sent within multipart request bodies via HTTP POST requests -- by :user:`webknjaz`. `#5494 `_ - Fix 0 being incorrectly treated as an immediate timeout. `#5527 `_ - Fixes failing tests when an environment variable _proxy is set. `#5554 `_ - Replace deprecated app handler design in ``tests/autobahn/server.py`` with call to ``web.run_app``; replace deprecated ``aiohttp.ws_connect`` calls in ``tests/autobahn/client.py`` with ``aiohttp.ClientSession.ws_connect``. `#5606 `_ - Fixed test for ``HTTPUnauthorized`` that accesses the ``text`` argument. This is not used in any part of the code, so it's removed now. `#5657 `_ - Remove incorrect default from docs `#5727 `_ - Remove external test dependency to http://httpbin.org `#5840 `_ - Don't cancel current task when entering a cancelled timer. `#5853 `_ - Added ``params`` keyword argument to ``ClientSession.ws_connect``. -- :user:`hoh`. `#5868 `_ - Uses ``asyncio.ThreadedChildWatcher`` under POSIX to allow setting up test loop in non-main thread. `#5877 `_ - Fix the error in handling the return value of `getaddrinfo`. `getaddrinfo` will return an `(int, bytes)` tuple, if CPython could not handle the address family. It will cause an index out of range error in aiohttp. For example, this happens if a user compiles CPython with the `--disable-ipv6` option but their system has IPv6 enabled. `#5901 `_ - Removed the deprecated ``loop`` argument from the ``asyncio.sleep``/``gather`` calls `#5905 `_ - Return ``None`` from ``request.if_modified_since``, ``request.if_unmodified_since``, ``request.if_range`` and ``response.last_modified`` when corresponding http date headers are invalid. `#5925 `_ - Fix resetting `SIGCHLD` signals in Gunicorn aiohttp Worker to fix `subprocesses` that capture output having an incorrect `returncode`.
`#6130 `_ - Raise ``400: Content-Length can't be present with Transfer-Encoding`` if both ``Content-Length`` and ``Transfer-Encoding`` are sent by peer by both C and Python implementations `#6182 `_ Improved Documentation ---------------------- - Refactored OpenAPI/Swagger aiohttp addons, added ``aio-openapi`` `#5326 `_ - Fixed docs on request cookies type, so it matches what is actually used in the code (a read-only dictionary-like object). `#5725 `_ - Documented that the HTTP client ``Authorization`` header is removed on redirects to a different host or protocol. `#5850 `_ Misc ---- - `#3927 `_, `#4247 `_, `#4247 `_, `#5389 `_, `#5457 `_, `#5486 `_, `#5494 `_, `#5515 `_, `#5625 `_, `#5635 `_, `#5648 `_, `#5657 `_, `#5890 `_, `#5914 `_, `#5932 `_, `#6002 `_, `#6045 `_, `#6131 `_, `#6156 `_, `#6165 `_, `#6166 `_ ---- 3.7.4.post0 (2021-03-06) ======================== Misc ---- - Bumped upper bound of the ``chardet`` runtime dependency to allow their v4.0 version stream. `#5366 `_ ---- 3.7.4 (2021-02-25) ================== Bugfixes -------- - **(SECURITY BUG)** Started preventing open redirects in the ``aiohttp.web.normalize_path_middleware`` middleware. For more details, see https://github.com/aio-libs/aiohttp/security/advisories/GHSA-v6wp-4m6f-gcjg. Thanks to `Beast Glatisant `__ for finding the first instance of this issue and `Jelmer Vernooij `__ for reporting and tracking it down in aiohttp. `#5497 `_ - Fix interpretation difference of the pure-Python and the Cython-based HTTP parsers construct a ``yarl.URL`` object for HTTP request-target. Before this fix, the Python parser would turn the URI's absolute-path for ``//some-path`` into ``/`` while the Cython code preserved it as ``//some-path``. Now, both do the latter. `#5498 `_ ---- 3.7.3 (2020-11-18) ================== Features -------- - Use Brotli instead of brotlipy `#3803 `_ - Made exceptions pickleable. Also changed the repr of some exceptions. 
`#4077 `_ Bugfixes -------- - Raise a ClientResponseError instead of an AssertionError for a blank HTTP Reason Phrase. `#3532 `_ - Fix ``web_middlewares.normalize_path_middleware`` behavior for patch without slash. `#3669 `_ - Fix overshadowing of overlapped sub-applications prefixes. `#3701 `_ - Make `BaseConnector.close()` a coroutine and wait until the client closes all connections. Drop deprecated "with Connector():" syntax. `#3736 `_ - Reset the ``sock_read`` timeout each time data is received for a ``aiohttp.client`` response. `#3808 `_ - Fixed type annotation for add_view method of UrlDispatcher to accept any subclass of View `#3880 `_ - Fixed querying the address families from DNS that the current host supports. `#5156 `_ - Change return type of MultipartReader.__aiter__() and BodyPartReader.__aiter__() to AsyncIterator. `#5163 `_ - Provide x86 Windows wheels. `#5230 `_ Improved Documentation ---------------------- - Add documentation for ``aiohttp.web.FileResponse``. `#3958 `_ - Removed deprecation warning in tracing example docs `#3964 `_ - Fixed wrong "Usage" docstring of ``aiohttp.client.request``. `#4603 `_ - Add aiohttp-pydantic to third party libraries `#5228 `_ Misc ---- - `#4102 `_ ---- 3.7.2 (2020-10-27) ================== Bugfixes -------- - Fixed static files handling for loops without ``.sendfile()`` support `#5149 `_ ---- 3.7.1 (2020-10-25) ================== Bugfixes -------- - Fixed a type error caused by the conditional import of `Protocol`. `#5111 `_ - Server doesn't send Content-Length for 1xx or 204 `#4901 `_ - Fix run_app typing `#4957 `_ - Always require ``typing_extensions`` library. `#5107 `_ - Fix a variable-shadowing bug causing `ThreadedResolver.resolve` to return the resolved IP as the ``hostname`` in each record, which prevented validation of HTTPS connections. `#5110 `_ - Added annotations to all public attributes. 
`#5115 `_ - Fix flaky test_when_timeout_smaller_second `#5116 `_ - Ensure sending a zero byte file does not throw an exception `#5124 `_ - Fix a bug in ``web.run_app()`` about Python version checking on Windows `#5127 `_ ---- 3.7.0 (2020-10-24) ================== Features -------- - Response headers are now prepared prior to running ``on_response_prepare`` hooks, directly before headers are sent to the client. `#1958 `_ - Add a ``quote_cookie`` option to ``CookieJar``, a way to skip quotation wrapping of cookies containing special characters. `#2571 `_ - Call ``AccessLogger.log`` with the current exception available from ``sys.exc_info()``. `#3557 `_ - `web.UrlDispatcher.add_routes` and `web.Application.add_routes` return a list of registered `AbstractRoute` instances. `AbstractRouteDef.register` (and all subclasses) return a list of registered resources registered resource. `#3866 `_ - Added properties of default ClientSession params to ClientSession class so it is available for introspection `#3882 `_ - Don't cancel web handler on peer disconnection, raise `OSError` on reading/writing instead. `#4080 `_ - Implement BaseRequest.get_extra_info() to access a protocol transports' extra info. `#4189 `_ - Added `ClientSession.timeout` property. `#4191 `_ - allow use of SameSite in cookies. `#4224 `_ - Use ``loop.sendfile()`` instead of custom implementation if available. `#4269 `_ - Apply SO_REUSEADDR to test server's socket. `#4393 `_ - Use .raw_host instead of slower .host in client API `#4402 `_ - Allow configuring the buffer size of input stream by passing ``read_bufsize`` argument. `#4453 `_ - Pass tests on Python 3.8 for Windows. `#4513 `_ - Add `method` and `url` attributes to `TraceRequestChunkSentParams` and `TraceResponseChunkReceivedParams`. `#4674 `_ - Add ClientResponse.ok property for checking status code under 400. `#4711 `_ - Don't ceil timeouts that are smaller than 5 seconds. 
`#4850 `_ - TCPSite now listens by default on all interfaces instead of just IPv4 when `None` is passed in as the host. `#4894 `_ - Bump ``http_parser`` to 2.9.4 `#5070 `_ Bugfixes -------- - Fix keepalive connections not being closed in time `#3296 `_ - Fix failed websocket handshake leaving connection hanging. `#3380 `_ - Fix tasks cancellation order on exit. The run_app task needs to be cancelled first for cleanup hooks to run with all tasks intact. `#3805 `_ - Don't start heartbeat until _writer is set `#4062 `_ - Fix handling of multipart file uploads without a content type. `#4089 `_ - Preserve view handler function attributes across middlewares `#4174 `_ - Fix the string representation of ``ServerDisconnectedError``. `#4175 `_ - Raising RuntimeError when trying to get encoding from not read body `#4214 `_ - Remove warning messages from noop. `#4282 `_ - Raise ClientPayloadError if FormData re-processed. `#4345 `_ - Fix a warning about unfinished task in ``web_protocol.py`` `#4408 `_ - Fixed 'deflate' compression. According to RFC 2616 now. `#4506 `_ - Fixed OverflowError on platforms with 32-bit time_t `#4515 `_ - Fixed request.body_exists returns wrong value for methods without body. `#4528 `_ - Fix connecting to link-local IPv6 addresses. `#4554 `_ - Fix a problem with connection waiters that are never awaited. `#4562 `_ - Always make sure the transport is not closing before reusing a connection. Reusing a protocol based on keepalive in headers is unreliable. For example, uWSGI will not support keepalive even if it serves an HTTP 1.1 request, unless uWSGI is explicitly configured with a ``--http-keepalive`` option. Servers designed like uWSGI could cause aiohttp to intermittently raise a ConnectionResetException when the protocol pool runs out and some protocol is reused. `#4587 `_ - Handle the last CRLF correctly even if it is received via separate TCP segment.
`#4630 `_ - Fix the register_resource function to validate route name before splitting it so that route name can include python keywords. `#4691 `_ - Improve typing annotations for ``web.Request``, ``aiohttp.ClientResponse`` and ``multipart`` module. `#4736 `_ - Fix resolver task not being awaited when connector is cancelled `#4795 `_ - Fix a bug "Aiohttp doesn't return any error on invalid request methods" `#4798 `_ - Fix HEAD requests for static content. `#4809 `_ - Fix incorrect size calculation for memoryview `#4890 `_ - Add HTTPMove to ``__all__``. `#4897 `_ - Fixed the type annotations in the ``tracing`` module. `#4912 `_ - Fix typing for multipart ``__aiter__``. `#4931 `_ - Fix for race condition on connections in BaseConnector that leads to exceeding the connection limit. `#4936 `_ - Add forced UTF-8 encoding for ``application/rdap+json`` responses. `#4938 `_ - Fix inconsistency between Python and C http request parsers in parsing pct-encoded URL. `#4972 `_ - Fix connection closing issue in HEAD request. `#5012 `_ - Fix type hint on BaseRunner.addresses (from ``List[str]`` to ``List[Any]``) `#5086 `_ - Make `web.run_app()` more responsive to Ctrl+C on Windows for Python < 3.8. It slightly increases CPU load as a side effect. `#5098 `_ Improved Documentation ---------------------- - Fix example code in client quick-start `#3376 `_ - Updated the docs so there is no contradiction in ``ttl_dns_cache`` default value `#3512 `_ - Add 'Deploy with SSL' to docs. `#4201 `_ - Change typing of the secure argument on StreamResponse.set_cookie from ``Optional[str]`` to ``Optional[bool]`` `#4204 `_ - Changes ``ttl_dns_cache`` type from int to Optional[int]. `#4270 `_ - Simplify README hello world example and add a documentation page for people coming from requests. `#4272 `_ - Improve some code examples in the documentation involving websockets and starting a simple HTTP site with an AppRunner.
`#4285 `_ - Fix typo in code example in Multipart docs `#4312 `_ - Fix code example in Multipart section. `#4314 `_ - Update contributing guide so new contributors read the most recent version of that guide. Update command used to create test coverage reporting. `#4810 `_ - Spelling: Change "canonize" to "canonicalize". `#4986 `_ - Add ``aiohttp-sse-client`` library to third party usage list. `#5084 `_ Misc ---- - `#2856 `_, `#4218 `_, `#4250 `_ ---- 3.6.3 (2020-10-12) ================== Bugfixes -------- - Pin yarl to ``<1.6.0`` to avoid buggy behavior that will be fixed by the next aiohttp release. 3.6.2 (2019-10-09) ================== Features -------- - Made exceptions pickleable. Also changed the repr of some exceptions. `#4077 `_ - Use ``Iterable`` type hint instead of ``Sequence`` for ``Application`` *middleware* parameter. `#4125 `_ Bugfixes -------- - Reset the ``sock_read`` timeout each time data is received for a ``aiohttp.ClientResponse``. `#3808 `_ - Fix handling of expired cookies so they are not stored in CookieJar. `#4063 `_ - Fix misleading message in the string representation of ``ClientConnectorError``; ``self.ssl == None`` means default SSL context, not SSL disabled `#4097 `_ - Don't clobber HTTP status when using FileResponse. `#4106 `_ Improved Documentation ---------------------- - Added minimal required logging configuration to logging documentation. `#2469 `_ - Update docs to reflect proxy support. `#4100 `_ - Fix typo in code example in testing docs. `#4108 `_ Misc ---- - `#4102 `_ ---- 3.6.1 (2019-09-19) ================== Features -------- - Compatibility with Python 3.8. `#4056 `_ Bugfixes -------- - correct some exception string format `#4068 `_ - Emit a warning when ``ssl.OP_NO_COMPRESSION`` is unavailable because the runtime is built against an outdated OpenSSL. `#4052 `_ - Update multidict requirement to >= 4.5 `#4057 `_ Improved Documentation ---------------------- - Provide pytest-aiohttp namespace for pytest fixtures in docs. 
`#3723 `_ ---- 3.6.0 (2019-09-06) ================== Features -------- - Add support for Named Pipes (Site and Connector) under Windows. This feature requires Proactor event loop to work. `#3629 `_ - Removed ``Transfer-Encoding: chunked`` header from websocket responses to be compatible with more http proxy servers. `#3798 `_ - Accept non-GET request for starting websocket handshake on server side. `#3980 `_ Bugfixes -------- - Raise a ClientResponseError instead of an AssertionError for a blank HTTP Reason Phrase. `#3532 `_ - Fix an issue where cookies would sometimes not be set during a redirect. `#3576 `_ - Change normalize_path_middleware to use '308 Permanent Redirect' instead of 301. This behavior should prevent clients from being unable to use PUT/POST methods on endpoints that are redirected because of a trailing slash. `#3579 `_ - Drop the processed task from ``all_tasks()`` list early. It prevents logging about a task with unhandled exception when the server is used in conjunction with ``asyncio.run()``. `#3587 `_ - ``Signal`` type annotation changed from ``Signal[Callable[['TraceConfig'], Awaitable[None]]]`` to ``Signal[Callable[ClientSession, SimpleNamespace, ...]``. `#3595 `_ - Use sanitized URL as Location header in redirects `#3614 `_ - Improve typing annotations for multipart.py along with changes required by mypy in files that references multipart.py. `#3621 `_ - Close session created inside ``aiohttp.request`` when unhandled exception occurs `#3628 `_ - Cleanup per-chunk data in generic data read. Memory leak fixed. `#3631 `_ - Use correct type for add_view and family `#3633 `_ - Fix _keepalive field in __slots__ of ``RequestHandler``. `#3644 `_ - Properly handle ConnectionResetError, to silence the "Cannot write to closing transport" exception when clients disconnect uncleanly. `#3648 `_ - Suppress pytest warnings due to ``test_utils`` classes `#3660 `_ - Fix overshadowing of overlapped sub-application prefixes. 
`#3701 `_ - Fixed return type annotation for WSMessage.json() `#3720 `_ - Properly expose TooManyRedirects publicly as documented. `#3818 `_ - Fix missing brackets for IPv6 in proxy CONNECT request `#3841 `_ - Make the signature of ``aiohttp.test_utils.TestClient.request`` match ``asyncio.ClientSession.request`` according to the docs `#3852 `_ - Use correct style for re-exported imports, makes mypy ``--strict`` mode happy. `#3868 `_ - Fixed type annotation for add_view method of UrlDispatcher to accept any subclass of View `#3880 `_ - Made cython HTTP parser set Reason-Phrase of the response to an empty string if it is missing. `#3906 `_ - Add URL to the string representation of ClientResponseError. `#3959 `_ - Accept ``istr`` keys in ``LooseHeaders`` type hints. `#3976 `_ - Fixed race conditions in _resolve_host caching and throttling when tracing is enabled. `#4013 `_ - For URLs like "unix://localhost/..." set Host HTTP header to "localhost" instead of "localhost:None". `#4039 `_ Improved Documentation ---------------------- - Modify documentation for Background Tasks to remove deprecated usage of event loop. `#3526 `_ - use ``if __name__ == '__main__':`` in server examples. `#3775 `_ - Update documentation reference to the default access logger. `#3783 `_ - Improve documentation for ``web.BaseRequest.path`` and ``web.BaseRequest.raw_path``. `#3791 `_ - Removed deprecation warning in tracing example docs `#3964 `_ ---- 3.5.4 (2019-01-12) ================== Bugfixes -------- - Fix stream ``.read()`` / ``.readany()`` / ``.iter_any()`` which used to return a partial content only in case of compressed content `#3525 `_ 3.5.3 (2019-01-10) ================== Bugfixes -------- - Fix type stubs for ``aiohttp.web.run_app(access_log=True)`` and fix edge case of ``access_log=True`` and the event loop being in debug mode. 
`#3504 `_ - Fix ``aiohttp.ClientTimeout`` type annotations to accept ``None`` for fields `#3511 `_ - Send custom per-request cookies even if session jar is empty `#3515 `_ - Restore Linux binary wheels publishing on PyPI ---- 3.5.2 (2019-01-08) ================== Features -------- - ``FileResponse`` from ``web_fileresponse.py`` uses a ``ThreadPoolExecutor`` to work with files asynchronously. I/O based payloads from ``payload.py`` uses a ``ThreadPoolExecutor`` to work with I/O objects asynchronously. `#3313 `_ - Internal Server Errors in plain text if the browser does not support HTML. `#3483 `_ Bugfixes -------- - Preserve MultipartWriter parts headers on write. Refactor the way how ``Payload.headers`` are handled. Payload instances now always have headers and Content-Type defined. Fix Payload Content-Disposition header reset after initial creation. `#3035 `_ - Log suppressed exceptions in ``GunicornWebWorker``. `#3464 `_ - Remove wildcard imports. `#3468 `_ - Use the same task for app initialization and web server handling in gunicorn workers. It allows to use Python3.7 context vars smoothly. `#3471 `_ - Fix handling of chunked+gzipped response when first chunk does not give uncompressed data `#3477 `_ - Replace ``collections.MutableMapping`` with ``collections.abc.MutableMapping`` to avoid a deprecation warning. `#3480 `_ - ``Payload.size`` type annotation changed from ``Optional[float]`` to ``Optional[int]``. `#3484 `_ - Ignore done tasks when cancels pending activities on ``web.run_app`` finalization. `#3497 `_ Improved Documentation ---------------------- - Add documentation for ``aiohttp.web.HTTPException``. `#3490 `_ Misc ---- - `#3487 `_ ---- 3.5.1 (2018-12-24) ==================== - Fix a regression about ``ClientSession._requote_redirect_url`` modification in debug mode. 3.5.0 (2018-12-22) ==================== Features -------- - The library type annotations are checked in strict mode now. 
- Add support for setting cookies for individual request (`#2387 `_) - Application.add_domain implementation (`#2809 `_) - The default ``app`` in the request returned by ``test_utils.make_mocked_request`` can now have objects assigned to it and retrieved using the ``[]`` operator. (`#3174 `_) - Make ``request.url`` accessible when transport is closed. (`#3177 `_) - Add ``zlib_executor_size`` argument to ``Response`` constructor to allow compression to run in a background executor to avoid blocking the main thread and potentially triggering health check failures. (`#3205 `_) - Enable users to set ``ClientTimeout`` in ``aiohttp.request`` (`#3213 `_) - Don't raise a warning if ``NETRC`` environment variable is not set and ``~/.netrc`` file doesn't exist. (`#3267 `_) - Add default logging handler to web.run_app. If the ``Application.debug`` flag is set and the default logger ``aiohttp.access`` is used, access logs will now be output using a *stderr* ``StreamHandler`` if no handlers are attached. Furthermore, if the default logger has no log level set, the log level will be set to ``DEBUG``. (`#3324 `_) - Add method argument to ``session.ws_connect()``. Sometimes server API requires a different HTTP method for WebSocket connection establishment. For example, ``Docker exec`` needs POST. (`#3378 `_) - Create a task per request handling. (`#3406 `_) Bugfixes -------- - Enable passing ``access_log_class`` via ``handler_args`` (`#3158 `_) - Return empty bytes with end-of-chunk marker in empty stream reader. (`#3186 `_) - Accept ``CIMultiDictProxy`` instances for ``headers`` argument in ``web.Response`` constructor. (`#3207 `_) - Don't uppercase HTTP method in parser (`#3233 `_) - Make method match regexp RFC-7230 compliant (`#3235 `_) - Add ``app.pre_frozen`` state to properly handle startup signals in sub-applications. (`#3237 `_) - Enhanced parsing and validation of helpers.BasicAuth.decode. (`#3239 `_) - Change imports from collections module in preparation for 3.8.
(`#3258 `_) - Ensure Host header is added first to ClientRequest to better replicate browser (`#3265 `_) - Fix forward compatibility with Python 3.8: importing ABCs directly from the collections module will not be supported anymore. (`#3273 `_) - Keep the query string by ``normalize_path_middleware``. (`#3278 `_) - Fix missing parameter ``raise_for_status`` for aiohttp.request() (`#3290 `_) - Bracket IPv6 addresses in the HOST header (`#3304 `_) - Fix default message for server ping and pong frames. (`#3308 `_) - Fix tests/test_connector.py typo and tests/autobahn/server.py duplicate loop def. (`#3337 `_) - Fix false-negative indicator end_of_HTTP_chunk in StreamReader.readchunk function (`#3361 `_) - Release HTTP response before raising status exception (`#3364 `_) - Fix task cancellation when ``sendfile()`` syscall is used by static file handling. (`#3383 `_) - Fix stack trace for ``asyncio.TimeoutError`` which was not logged, when it is caught in the handler. (`#3414 `_) Improved Documentation ---------------------- - Improve documentation of ``Application.make_handler`` parameters. (`#3152 `_) - Fix BaseRequest.raw_headers doc. (`#3215 `_) - Fix typo in TypeError exception reason in ``web.Application._handle`` (`#3229 `_) - Make server access log format placeholder %b documentation reflect behavior and docstring. (`#3307 `_) Deprecations and Removals ------------------------- - Deprecate modification of ``session.requote_redirect_url`` (`#2278 `_) - Deprecate ``stream.unread_data()`` (`#3260 `_) - Deprecated use of boolean in ``resp.enable_compression()`` (`#3318 `_) - Encourage creation of aiohttp public objects inside a coroutine (`#3331 `_) - Drop dead ``Connection.detach()`` and ``Connection.writer``. Both methods were broken for more than 2 years. (`#3358 `_) - Deprecate ``app.loop``, ``request.loop``, ``client.loop`` and ``connector.loop`` properties. (`#3374 `_) - Deprecate explicit debug argument. Use asyncio debug mode instead. 
(`#3381 `_) - Deprecate body parameter in HTTPException (and derived classes) constructor. (`#3385 `_) - Deprecate bare connector close, use ``async with connector:`` and ``await connector.close()`` instead. (`#3417 `_) - Deprecate obsolete ``read_timeout`` and ``conn_timeout`` in ``ClientSession`` constructor. (`#3438 `_) Misc ---- - #3341, #3351 ---- 3.4.4 (2018-09-05) ================== - Fix installation from sources when compiling toolkit is not available (`#3241 `_) ---- 3.4.3 (2018-09-04) ================== - Add ``app.pre_frozen`` state to properly handle startup signals in sub-applications. (`#3237 `_) ---- 3.4.2 (2018-09-01) ================== - Fix ``iter_chunks`` type annotation (`#3230 `_) ---- 3.4.1 (2018-08-28) ================== - Fix empty header parsing regression. (`#3218 `_) - Fix BaseRequest.raw_headers doc. (`#3215 `_) - Fix documentation building on ReadTheDocs (`#3221 `_) ---- 3.4.0 (2018-08-25) ================== Features -------- - Add type hints (`#3049 `_) - Add ``raise_for_status`` request parameter (`#3073 `_) - Add type hints to HTTP client (`#3092 `_) - Minor server optimizations (`#3095 `_) - Preserve the cause when `HTTPException` is raised from another exception. (`#3096 `_) - Add `close_boundary` option in `MultipartWriter.write` method. Support streaming (`#3104 `_) - Added a ``remove_slash`` option to the ``normalize_path_middleware`` factory. (`#3173 `_) - The class `AbstractRouteDef` is importable from `aiohttp.web`. (`#3183 `_) Bugfixes -------- - Prevent double closing when client connection is released before the last ``data_received()`` callback. (`#3031 `_) - Make redirect with `normalize_path_middleware` work when using url encoded paths. (`#3051 `_) - Postpone web task creation to connection establishment. (`#3052 `_) - Fix ``sock_read`` timeout. 
(`#3053 `_) - When using a server-request body as the `data=` argument of a client request, iterate over the content with `readany` instead of `readline` to avoid `Line too long` errors. (`#3054 `_) - fix `UrlDispatcher` has no attribute `add_options`, add `web.options` (`#3062 `_) - correct filename in content-disposition with multipart body (`#3064 `_) - Many HTTP proxies has buggy keepalive support. Let's not reuse connection but close it after processing every response. (`#3070 `_) - raise 413 "Payload Too Large" rather than raising ValueError in request.post() Add helpful debug message to 413 responses (`#3087 `_) - Fix `StreamResponse` equality, now that they are `MutableMapping` objects. (`#3100 `_) - Fix server request objects comparison (`#3116 `_) - Do not hang on `206 Partial Content` response with `Content-Encoding: gzip` (`#3123 `_) - Fix timeout precondition checkers (`#3145 `_) Improved Documentation ---------------------- - Add a new FAQ entry that clarifies that you should not reuse response objects in middleware functions. (`#3020 `_) - Add FAQ section "Why is creating a ClientSession outside of an event loop dangerous?" (`#3072 `_) - Fix link to Rambler (`#3115 `_) - Fix TCPSite documentation on the Server Reference page. (`#3146 `_) - Fix documentation build configuration file for Windows. (`#3147 `_) - Remove no longer existing lingering_timeout parameter of Application.make_handler from documentation. (`#3151 `_) - Mention that ``app.make_handler`` is deprecated, recommend to use runners API instead. (`#3157 `_) Deprecations and Removals ------------------------- - Drop ``loop.current_task()`` from ``helpers.current_task()`` (`#2826 `_) - Drop ``reader`` parameter from ``request.multipart()``. (`#3090 `_) ---- 3.3.2 (2018-06-12) ================== - Many HTTP proxies has buggy keepalive support. Let's not reuse connection but close it after processing every response. 
(`#3070 `_) - Provide vendor source files in tarball (`#3076 `_) ---- 3.3.1 (2018-06-05) ================== - Fix ``sock_read`` timeout. (`#3053 `_) - When using a server-request body as the ``data=`` argument of a client request, iterate over the content with ``readany`` instead of ``readline`` to avoid ``Line too long`` errors. (`#3054 `_) ---- 3.3.0 (2018-06-01) ================== Features -------- - Raise ``ConnectionResetError`` instead of ``CancelledError`` on trying to write to a closed stream. (`#2499 `_) - Implement ``ClientTimeout`` class and support socket read timeout. (`#2768 `_) - Enable logging when ``aiohttp.web`` is used as a program (`#2956 `_) - Add canonical property to resources (`#2968 `_) - Forbid reading response BODY after release (`#2983 `_) - Implement base protocol class to avoid a dependency from internal ``asyncio.streams.FlowControlMixin`` (`#2986 `_) - Cythonize ``@helpers.reify``, 5% boost on macro benchmark (`#2995 `_) - Optimize HTTP parser (`#3015 `_) - Implement ``runner.addresses`` property. (`#3036 `_) - Use ``bytearray`` instead of a list of ``bytes`` in websocket reader. It improves websocket message reading a little. (`#3039 `_) - Remove heartbeat on closing connection on keepalive timeout. The used hack violates HTTP protocol. (`#3041 `_) - Limit websocket message size on reading to 4 MB by default. (`#3045 `_) Bugfixes -------- - Don't reuse a connection with the same URL but different proxy/TLS settings (`#2981 `_) - When parsing the Forwarded header, the optional port number is now preserved. (`#3009 `_) Improved Documentation ---------------------- - Make Change Log more visible in docs (`#3029 `_) - Make style and grammar improvements on the FAQ page. (`#3030 `_) - Document that signal handlers should be async functions since aiohttp 3.0 (`#3032 `_) Deprecations and Removals ------------------------- - Deprecate custom application's router. 
(`#3021 `_) Misc ---- - #3008, #3011 ---- 3.2.1 (2018-05-10) ================== - Don't reuse a connection with the same URL but different proxy/TLS settings (`#2981 `_) ---- 3.2.0 (2018-05-06) ================== Features -------- - Raise ``TooManyRedirects`` exception when client gets redirected too many times instead of returning last response. (`#2631 `_) - Extract route definitions into separate ``web_routedef.py`` file (`#2876 `_) - Raise an exception on request body reading after sending response. (`#2895 `_) - ClientResponse and RequestInfo now have real_url property, which is request url without fragment part being stripped (`#2925 `_) - Speed up connector limiting (`#2937 `_) - Added and links property for ClientResponse object (`#2948 `_) - Add ``request.config_dict`` for exposing nested applications data. (`#2949 `_) - Speed up HTTP headers serialization, server micro-benchmark runs 5% faster now. (`#2957 `_) - Apply assertions in debug mode only (`#2966 `_) Bugfixes -------- - expose property `app` for TestClient (`#2891 `_) - Call on_chunk_sent when write_eof takes as a param the last chunk (`#2909 `_) - A closing bracket was added to `__repr__` of resources (`#2935 `_) - Fix compression of FileResponse (`#2942 `_) - Fixes some bugs in the limit connection feature (`#2964 `_) Improved Documentation ---------------------- - Drop ``async_timeout`` usage from documentation for client API in favor of ``timeout`` parameter. (`#2865 `_) - Improve Gunicorn logging documentation (`#2921 `_) - Replace multipart writer `.serialize()` method with `.write()` in documentation. 
(`#2965 `_) Deprecations and Removals ------------------------- - Deprecate Application.make_handler() (`#2938 `_) Misc ---- - #2958 ---- 3.1.3 (2018-04-12) ================== - Fix cancellation broadcast during DNS resolve (`#2910 `_) ---- 3.1.2 (2018-04-05) ================== - Make ``LineTooLong`` exception more detailed about actual data size (`#2863 `_) - Call ``on_chunk_sent`` when write_eof takes as a param the last chunk (`#2909 `_) ---- 3.1.1 (2018-03-27) ================== - Support *asynchronous iterators* (and *asynchronous generators* as well) in both client and server API as request / response BODY payloads. (`#2802 `_) ---- 3.1.0 (2018-03-21) ================== Welcome to aiohttp 3.1 release. This is an *incremental* release, fully backward compatible with *aiohttp 3.0*. But we have added several new features. The most visible one is ``app.add_routes()`` (an alias for existing ``app.router.add_routes()``). The addition is very important because all *aiohttp* docs now use ``app.add_routes()`` call in code snippets. All your existing code still registers routes / resources without any warning but you've got the idea for a favorite way: noisy ``app.router.add_get()`` is replaced by ``app.add_routes()``. The library does not make a preference between decorators:: routes = web.RouteTableDef() @routes.get('/') async def hello(request): return web.Response(text="Hello, world") app.add_routes(routes) and route tables as a list:: async def hello(request): return web.Response(text="Hello, world") app.add_routes([web.get('/', hello)]) Both ways are equal, the user may decide based on their own code taste. Also we have a lot of minor features, bug fixes and documentation updates, see below. Features -------- - Relax JSON content-type checking in the ``ClientResponse.json()`` to allow "application/xxx+json" instead of strict "application/json".
(`#2206 `_) - Bump C HTTP parser to version 2.8 (`#2730 `_) - Accept a coroutine as an application factory in ``web.run_app`` and gunicorn worker. (`#2739 `_) - Implement application cleanup context (``app.cleanup_ctx`` property). (`#2747 `_) - Make ``writer.write_headers`` a coroutine. (`#2762 `_) - Add tracking signals for getting request/response bodies. (`#2767 `_) - Deprecate ClientResponseError.code in favor of .status to keep similarity with response classes. (`#2781 `_) - Implement ``app.add_routes()`` method. (`#2787 `_) - Implement ``web.static()`` and ``RouteTableDef.static()`` API. (`#2795 `_) - Install a test event loop as default by ``asyncio.set_event_loop()``. The change affects aiohttp test utils but backward compatibility is not broken for 99.99% of use cases. (`#2804 `_) - Refactor ``ClientResponse`` constructor: make logically required constructor arguments mandatory, drop ``_post_init()`` method. (`#2820 `_) - Use ``app.add_routes()`` in server docs everywhere (`#2830 `_) - Websockets refactoring, all websocket writer methods are converted into coroutines. (`#2836 `_) - Provide ``Content-Range`` header for ``Range`` requests (`#2844 `_) Bugfixes -------- - Fix websocket client return EofStream. (`#2784 `_) - Fix websocket demo. (`#2789 `_) - Property ``BaseRequest.http_range`` now returns a python-like slice when requesting the tail of the range. It's now indicated by a negative value in ``range.start`` rather than in ``range.stop`` (`#2805 `_) - Close a connection if an unexpected exception occurs while sending a request (`#2827 `_) - Fix firing DNS tracing events. (`#2841 `_) Improved Documentation ---------------------- - Document behavior when cchardet detects encodings that are unknown to Python. (`#2732 `_) - Add diagrams for tracing request life cycle. (`#2748 `_) - Drop removed functionality for passing ``StreamReader`` as data at client side.
(`#2793 `_) ---- 3.0.9 (2018-03-14) ================== - Close a connection if an unexpected exception occurs while sending a request (`#2827 `_) ---- 3.0.8 (2018-03-12) ================== - Use ``asyncio.current_task()`` on Python 3.7 (`#2825 `_) ---- 3.0.7 (2018-03-08) ================== - Fix SSL proxy support by client. (`#2810 `_) - Restore an imperative check in ``setup.py`` for python version. The check works in parallel to environment marker. As effect an error about unsupported Python versions is raised even on outdated systems with very old ``setuptools`` version installed. (`#2813 `_) ---- 3.0.6 (2018-03-05) ================== - Add ``_reuse_address`` and ``_reuse_port`` to ``web_runner.TCPSite.__slots__``. (`#2792 `_) ---- 3.0.5 (2018-02-27) ================== - Fix ``InvalidStateError`` on processing a sequence of two ``RequestHandler.data_received`` calls on web server. (`#2773 `_) ---- 3.0.4 (2018-02-26) ================== - Fix ``IndexError`` in HTTP request handling by server. (`#2752 `_) - Fix MultipartWriter.append* no longer returning part/payload. (`#2759 `_) ---- 3.0.3 (2018-02-25) ================== - Relax ``attrs`` dependency to minimal actually supported version 17.0.3 The change allows to avoid version conflicts with currently existing test tools. ---- 3.0.2 (2018-02-23) ================== Security Fix ------------ - Prevent Windows absolute URLs in static files. Paths like ``/static/D:\path`` and ``/static/\\hostname\drive\path`` are forbidden. ---- 3.0.1 ===== - Technical release for fixing distribution problems. ---- 3.0.0 (2018-02-12) ================== Features -------- - Speed up the `PayloadWriter.write` method for large request bodies. (`#2126 `_) - StreamResponse and Response are now MutableMappings. (`#2246 `_) - ClientSession publishes a set of signals to track the HTTP request execution. (`#2313 `_) - Content-Disposition fast access in ClientResponse (`#2455 `_) - Added support to Flask-style decorators with class-based Views. 
(`#2472 `_) - Signal handlers (registered callbacks) should be coroutines. (`#2480 `_) - Support ``async with test_client.ws_connect(...)`` (`#2525 `_) - Introduce *site* and *application runner* as underlying API for `web.run_app` implementation. (`#2530 `_) - Only quote multipart boundary when necessary and sanitize input (`#2544 `_) - Make the `aiohttp.ClientResponse.get_encoding` method public with the processing of invalid charset while detecting content encoding. (`#2549 `_) - Add optional configurable per message compression for `ClientWebSocketResponse` and `WebSocketResponse`. (`#2551 `_) - Add hysteresis to `StreamReader` to prevent flipping between paused and resumed states too often. (`#2555 `_) - Support `.netrc` by `trust_env` (`#2581 `_) - Avoid to create a new resource when adding a route with the same name and path of the last added resource (`#2586 `_) - `MultipartWriter.boundary` is `str` now. (`#2589 `_) - Allow a custom port to be used by `TestServer` (and associated pytest fixtures) (`#2613 `_) - Add param access_log_class to web.run_app function (`#2615 `_) - Add ``ssl`` parameter to client API (`#2626 `_) - Fixes performance issue introduced by #2577. When there are no middlewares installed by the user, no additional and useless code is executed. (`#2629 `_) - Rename PayloadWriter to StreamWriter (`#2654 `_) - New options *reuse_port*, *reuse_address* are added to `run_app` and `TCPSite`. (`#2679 `_) - Use custom classes to pass client signals parameters (`#2686 `_) - Use ``attrs`` library for data classes, replace `namedtuple`. (`#2690 `_) - Pytest fixtures renaming, add ``aiohttp_`` prefix (`#2578 `_) - Add ``aiohttp-`` prefix for ``pytest-aiohttp`` command line parameters (`#2578 `_) Bugfixes -------- - Correctly process upgrade request from server to HTTP2. ``aiohttp`` does not support HTTP2 yet, the protocol is not upgraded but response is handled correctly. 
(`#2277 `_) - Fix ClientConnectorSSLError and ClientProxyConnectionError for proxy connector (`#2408 `_) - Fix connector convert OSError to ClientConnectorError (`#2423 `_) - Fix connection attempts for multiple dns hosts (`#2424 `_) - Fix writing to closed transport by raising `asyncio.CancelledError` (`#2499 `_) - Fix warning in `ClientSession.__del__` by stopping to try to close it. (`#2523 `_) - Fixed race-condition for iterating addresses from the DNSCache. (`#2620 `_) - Fix default value of `access_log_format` argument in `web.run_app` (`#2649 `_) - Freeze sub-application on adding to parent app (`#2656 `_) - Do percent encoding for `.url_for()` parameters (`#2668 `_) - Correctly process request start time and multiple request/response headers in access log extra (`#2641 `_) Improved Documentation ---------------------- - Improve tutorial docs, using `literalinclude` to link to the actual files. (`#2396 `_) - Small improvement docs: better example for file uploads. (`#2401 `_) - Rename `from_env` to `trust_env` in client reference. (`#2451 `_) - Fixed mistype in `Proxy Support` section where `trust_env` parameter was used in `session.get("http://python.org", trust_env=True)` method instead of aiohttp.ClientSession constructor as follows: `aiohttp.ClientSession(trust_env=True)`. (`#2688 `_) - Fix issue with unittest example not compiling in testing docs. (`#2717 `_) Deprecations and Removals ------------------------- - Simplify HTTP pipelining implementation (`#2109 `_) - Drop `StreamReaderPayload` and `DataQueuePayload`. (`#2257 `_) - Drop `md5` and `sha1` finger-prints (`#2267 `_) - Drop WSMessage.tp (`#2321 `_) - Drop Python 3.4 and Python 3.5.0, 3.5.1, 3.5.2. Minimal supported Python versions are 3.5.3 and 3.6.0. `yield from` is gone, use `async/await` syntax. (`#2343 `_) - Drop `aiohttp.Timeout` and use `async_timeout.timeout` instead. (`#2348 `_) - Drop `resolve` param from TCPConnector. 
(`#2377 `_) - Add DeprecationWarning for returning HTTPException (`#2415 `_) - `send_str()`, `send_bytes()`, `send_json()`, `ping()` and `pong()` are genuine async functions now. (`#2475 `_) - Drop undocumented `app.on_pre_signal` and `app.on_post_signal`. Signal handlers should be coroutines, support for regular functions is dropped. (`#2480 `_) - `StreamResponse.drain()` is not a part of public API anymore, just use `await StreamResponse.write()`. `StreamResponse.write` is converted to async function. (`#2483 `_) - Drop deprecated `slow_request_timeout` param and `**kwargs`` from `RequestHandler`. (`#2500 `_) - Drop deprecated `resource.url()`. (`#2501 `_) - Remove `%u` and `%l` format specifiers from access log format. (`#2506 `_) - Drop deprecated `request.GET` property. (`#2547 `_) - Simplify stream classes: drop `ChunksQueue` and `FlowControlChunksQueue`, merge `FlowControlStreamReader` functionality into `StreamReader`, drop `FlowControlStreamReader` name. (`#2555 `_) - Do not create a new resource on `router.add_get(..., allow_head=True)` (`#2585 `_) - Drop access to TCP tuning options from PayloadWriter and Response classes (`#2604 `_) - Drop deprecated `encoding` parameter from client API (`#2606 `_) - Deprecate ``verify_ssl``, ``ssl_context`` and ``fingerprint`` parameters in client API (`#2626 `_) - Get rid of the legacy class StreamWriter. (`#2651 `_) - Forbid non-strings in `resource.url_for()` parameters. (`#2668 `_) - Deprecate inheritance from ``ClientSession`` and ``web.Application`` and custom user attributes for ``ClientSession``, ``web.Request`` and ``web.Application`` (`#2691 `_) - Drop `resp = await aiohttp.request(...)` syntax for sake of `async with aiohttp.request(...) as resp:`. (`#2540 `_) - Forbid synchronous context managers for `ClientSession` and test server/client. 
(`#2362 `_) Misc ---- - #2552 ---- 2.3.10 (2018-02-02) =================== - Fix 100% CPU usage on HTTP GET and websocket connection just after it (`#1955 `_) - Patch broken `ssl.match_hostname()` on Python<3.7 (`#2674 `_) ---- 2.3.9 (2018-01-16) ================== - Fix colon handling in path for dynamic resources (`#2670 `_) ---- 2.3.8 (2018-01-15) ================== - Do not use `yarl.unquote` internal function in aiohttp. Fix incorrectly unquoted path part in URL dispatcher (`#2662 `_) - Fix compatibility with `yarl==1.0.0` (`#2662 `_) ---- 2.3.7 (2017-12-27) ================== - Fixed race-condition for iterating addresses from the DNSCache. (`#2620 `_) - Fix docstring for request.host (`#2591 `_) - Fix docstring for request.remote (`#2592 `_) ---- 2.3.6 (2017-12-04) ================== - Correct `request.app` context (for handlers not just middlewares). (`#2577 `_) ---- 2.3.5 (2017-11-30) ================== - Fix compatibility with `pytest` 3.3+ (`#2565 `_) ---- 2.3.4 (2017-11-29) ================== - Make `request.app` point to proper application instance when using nested applications (with middlewares). (`#2550 `_) - Change base class of ClientConnectorSSLError to ClientSSLError from ClientConnectorError. (`#2563 `_) - Return client connection back to free pool on error in `connector.connect()`. (`#2567 `_) ---- 2.3.3 (2017-11-17) ================== - Having a `;` in Response content type does not assume it contains a charset anymore. (`#2197 `_) - Use `getattr(asyncio, 'async')` for keeping compatibility with Python 3.7. (`#2476 `_) - Ignore `NotImplementedError` raised by `set_child_watcher` from `uvloop`. (`#2491 `_) - Fix warning in `ClientSession.__del__` by stopping to try to close it. (`#2523 `_) - Fixed typos in the Third-party libraries page and added async-v20 to the list (`#2510 `_) ---- 2.3.2 (2017-11-01) ================== - Fix passing client max size on cloning request obj.
(`#2385 `_) - Fix ClientConnectorSSLError and ClientProxyConnectionError for proxy connector. (`#2408 `_) - Drop generated `_http_parser` shared object from tarball distribution. (`#2414 `_) - Fix connector convert OSError to ClientConnectorError. (`#2423 `_) - Fix connection attempts for multiple dns hosts. (`#2424 `_) - Fix ValueError for AF_INET6 sockets if a preexisting INET6 socket to the `aiohttp.web.run_app` function. (`#2431 `_) - `_SessionRequestContextManager` closes the session properly now. (`#2441 `_) - Rename `from_env` to `trust_env` in client reference. (`#2451 `_) ---- 2.3.1 (2017-10-18) ================== - Relax attribute lookup in warning about old-styled middleware (`#2340 `_) ---- 2.3.0 (2017-10-18) ================== Features -------- - Add SSL related params to `ClientSession.request` (`#1128 `_) - Make enable_compression work on HTTP/1.0 (`#1828 `_) - Deprecate registering synchronous web handlers (`#1993 `_) - Switch to `multidict 3.0`. All HTTP headers preserve casing now but compared in case-insensitive way. (`#1994 `_) - Improvement for `normalize_path_middleware`. Added possibility to handle URLs with query string. (`#1995 `_) - Use towncrier for CHANGES.txt build (`#1997 `_) - Implement `trust_env=True` param in `ClientSession`. (`#1998 `_) - Added variable to customize proxy headers (`#2001 `_) - Implement `router.add_routes` and router decorators. 
(`#2004 `_) - Deprecated `BaseRequest.has_body` in favor of `BaseRequest.can_read_body` Added `BaseRequest.body_exists` attribute that stays static for the lifetime of the request (`#2005 `_) - Provide `BaseRequest.loop` attribute (`#2024 `_) - Make `_CoroGuard` awaitable and fix `ClientSession.close` warning message (`#2026 `_) - Responses to redirects without Location header are returned instead of raising a RuntimeError (`#2030 `_) - Added `get_client`, `get_server`, `setUpAsync` and `tearDownAsync` methods to AioHTTPTestCase (`#2032 `_) - Add automatically a SafeChildWatcher to the test loop (`#2058 `_) - add ability to disable automatic response decompression (`#2110 `_) - Add support for throttling DNS request, avoiding the requests saturation when there is a miss in the DNS cache and many requests getting into the connector at the same time. (`#2111 `_) - Use request for getting access log information instead of message/transport pair. Add `RequestBase.remote` property for accessing to IP of client initiated HTTP request. (`#2123 `_) - json() raises a ContentTypeError exception if the content-type does not meet the requirements instead of raising a generic ClientResponseError. (`#2136 `_) - Make the HTTP client able to return HTTP chunks when chunked transfer encoding is used. (`#2150 `_) - add `append_version` arg into `StaticResource.url` and `StaticResource.url_for` methods for getting an url with hash (version) of the file. (`#2157 `_) - Fix parsing the Forwarded header. * commas and semicolons are allowed inside quoted-strings; * empty forwarded-pairs (as in for=_1;;by=_2) are allowed; * non-standard parameters are allowed (although this alone could be easily done in the previous parser). (`#2173 `_) - Don't require ssl module to run. aiohttp does not require SSL to function. The code paths involved with SSL will only be hit upon SSL usage. Raise `RuntimeError` if HTTPS protocol is required but ssl module is not present. 
(`#2221 `_) - Accept coroutine fixtures in pytest plugin (`#2223 `_) - Call `shutdown_asyncgens` before event loop closing on Python 3.6. (`#2227 `_) - Speed up Signals when there are no receivers (`#2229 `_) - Raise `InvalidURL` instead of `ValueError` on fetches with invalid URL. (`#2241 `_) - Move `DummyCookieJar` into `cookiejar.py` (`#2242 `_) - `run_app`: Make `print=None` disable printing (`#2260 `_) - Support `brotli` encoding (generic-purpose lossless compression algorithm) (`#2270 `_) - Add server support for WebSockets Per-Message Deflate. Add client option to add deflate compress header in WebSockets request header. If calling ClientSession.ws_connect() with `compress=15` the client will support deflate compress negotiation. (`#2273 `_) - Support `verify_ssl`, `fingerprint`, `ssl_context` and `proxy_headers` by `client.ws_connect`. (`#2292 `_) - Added `aiohttp.ClientConnectorSSLError` when connection fails due `ssl.SSLError` (`#2294 `_) - `aiohttp.web.Application.make_handler` support `access_log_class` (`#2315 `_) - Build HTTP parser extension in non-strict mode by default. (`#2332 `_) Bugfixes -------- - Clear auth information on redirecting to other domain (`#1699 `_) - Fix missing app.loop on startup hooks during tests (`#2060 `_) - Fix issue with synchronous session closing when using `ClientSession` as an asynchronous context manager. (`#2063 `_) - Fix issue with `CookieJar` incorrectly expiring cookies in some edge cases. (`#2084 `_) - Force use of IPv4 during test, this will make tests run in a Docker container (`#2104 `_) - Warnings about unawaited coroutines now correctly point to the user's code. (`#2106 `_) - Fix issue with `IndexError` being raised by the `StreamReader.iter_chunks()` generator. (`#2112 `_) - Support HTTP 308 Permanent redirect in client class. (`#2114 `_) - Fix `FileResponse` sending empty chunked body on 304. (`#2143 `_) - Do not add `Content-Length: 0` to GET/HEAD/TRACE/OPTIONS requests by default. 
(`#2167 `_) - Fix parsing the Forwarded header according to RFC 7239. (`#2170 `_) - Securely determining remote/scheme/host #2171 (`#2171 `_) - Fix header name parsing, if name is split into multiple lines (`#2183 `_) - Handle session close during connection, `KeyError: ` (`#2193 `_) - Fixes uncaught `TypeError` in `helpers.guess_filename` if `name` is not a string (`#2201 `_) - Raise OSError on async DNS lookup if resolved domain is an alias for another one, which does not have an A or CNAME record. (`#2231 `_) - Fix incorrect warning in `StreamReader`. (`#2251 `_) - Properly clone state of web request (`#2284 `_) - Fix C HTTP parser for cases when status line is split into different TCP packets. (`#2311 `_) - Fix `web.FileResponse` overriding user supplied Content-Type (`#2317 `_) Improved Documentation ---------------------- - Add a note about possible performance degradation in `await resp.text()` if charset was not provided by `Content-Type` HTTP header. Pass explicit encoding to solve it. (`#1811 `_) - Drop `disqus` widget from documentation pages. (`#2018 `_) - Add a graceful shutdown section to the client usage documentation. (`#2039 `_) - Document `connector_owner` parameter. (`#2072 `_) - Update the doc of web.Application (`#2081 `_) - Fix mistake about access log disabling. (`#2085 `_) - Add example usage of on_startup and on_shutdown signals by creating and disposing an aiopg connection engine. (`#2131 `_) - Document `encoded=True` for `yarl.URL`, it disables all yarl transformations. (`#2198 `_) - Document that all app's middleware factories are run for every request. (`#2225 `_) - Reflect the fact that default resolver is threaded one starting from aiohttp 1.1 (`#2228 `_) Deprecations and Removals ------------------------- - Drop deprecated `Server.finish_connections` (`#2006 `_) - Drop %O format from logging, use %b instead. Drop %e format from logging, environment variables are not supported anymore. 
(`#2123 `_) - Drop deprecated secure_proxy_ssl_header support (`#2171 `_) - Removed TimeService in favor of simple caching. TimeService also had a bug where it lost about 0.5 seconds per second. (`#2176 `_) - Drop unused response_factory from static files API (`#2290 `_) Misc ---- - #2013, #2014, #2048, #2094, #2149, #2187, #2214, #2225, #2243, #2248 ---- 2.2.5 (2017-08-03) ================== - Don't raise deprecation warning on `loop.run_until_complete(client.close())` (`#2065 `_) ---- 2.2.4 (2017-08-02) ================== - Fix issue with synchronous session closing when using ClientSession as an asynchronous context manager. (`#2063 `_) ---- 2.2.3 (2017-07-04) ================== - Fix `_CoroGuard` for python 3.4 ---- 2.2.2 (2017-07-03) ================== - Allow `await session.close()` along with `yield from session.close()` ---- 2.2.1 (2017-07-02) ================== - Relax `yarl` requirement to 0.11+ - Backport #2026: `session.close` *is* a coroutine (`#2029 `_) ---- 2.2.0 (2017-06-20) ================== - Add doc for add_head, update doc for add_get. (`#1944 `_) - Fixed consecutive calls for `Response.write_eof`. - Retain method attributes (e.g. :code:`__doc__`) when registering synchronous handlers for resources. (`#1953 `_) - Added signal TERM handling in `run_app` to gracefully exit (`#1932 `_) - Fix websocket issues caused by frame fragmentation. 
(`#1962 `_) - Raise RuntimeError if you try to set the Content Length and enable chunked encoding at the same time (`#1941 `_) - Small update for `unittest_run_loop` - Use CIMultiDict for ClientRequest.skip_auto_headers (`#1970 `_) - Fix wrong startup sequence: test server and `run_app()` do not raise `DeprecationWarning` now (`#1947 `_) - Make sure cleanup signal is sent if startup signal has been sent (`#1959 `_) - Fixed server keep-alive handler, could cause 100% cpu utilization (`#1955 `_) - Connection can be destroyed before response gets processed if `await aiohttp.request(..)` is used (`#1981 `_) - MultipartReader does not work with -OO (`#1969 `_) - Fixed `ClientPayloadError` with blank `Content-Encoding` header (`#1931 `_) - Support `deflate` encoding implemented in `httpbin.org/deflate` (`#1918 `_) - Fix BadStatusLine caused by extra `CRLF` after `POST` data (`#1792 `_) - Keep a reference to `ClientSession` in response object (`#1985 `_) - Deprecate undocumented `app.on_loop_available` signal (`#1978 `_) ---- 2.1.0 (2017-05-26) ================== - Added support for experimental `async-tokio` event loop written in Rust https://github.com/PyO3/tokio - Write to transport ``\r\n`` before closing after keepalive timeout, otherwise client can not detect socket disconnection. (`#1883 `_) - Only call `loop.close` in `run_app` if the user did *not* supply a loop. Useful for allowing clients to specify their own cleanup before closing the asyncio loop if they wish to tightly control loop behavior - Content disposition with semicolon in filename (`#917 `_) - Added `request_info` to response object and `ClientResponseError`. (`#1733 `_) - Added `history` to `ClientResponseError`. (`#1741 `_) - Allow to disable redirect url re-quoting (`#1474 `_) - Handle RuntimeError from transport (`#1790 `_) - Dropped "%O" in access logger (`#1673 `_) - Added `args` and `kwargs` to `unittest_run_loop`. Useful with other decorators, for example `@patch`.
(`#1803 `_) - Added `iter_chunks` to response.content object. (`#1805 `_) - Avoid creating TimerContext when there is no timeout to allow compatibility with Tornado. (`#1817 `_) (`#1180 `_) - Add `proxy_from_env` to `ClientRequest` to read from environment variables. (`#1791 `_) - Add DummyCookieJar helper. (`#1830 `_) - Fix assertion errors in Python 3.4 from noop helper. (`#1847 `_) - Do not unquote `+` in match_info values (`#1816 `_) - Use Forwarded, X-Forwarded-Scheme and X-Forwarded-Host for better scheme and host resolution. (`#1134 `_) - Fix sub-application middlewares resolution order (`#1853 `_) - Fix applications comparison (`#1866 `_) - Fix static location in index when prefix is used (`#1662 `_) - Make test server more reliable (`#1896 `_) - Extend list of web exceptions, add HTTPUnprocessableEntity, HTTPFailedDependency, HTTPInsufficientStorage status codes (`#1920 `_) ---- 2.0.7 (2017-04-12) ================== - Fix *pypi* distribution - Fix exception description (`#1807 `_) - Handle socket error in FileResponse (`#1773 `_) - Cancel websocket heartbeat on close (`#1793 `_) ---- 2.0.6 (2017-04-04) ================== - Keeping blank values for `request.post()` and `multipart.form()` (`#1765 `_) - TypeError in data_received of ResponseHandler (`#1770 `_) - Fix ``web.run_app`` not to bind to default host-port pair if only socket is passed (`#1786 `_) ---- 2.0.5 (2017-03-29) ================== - Memory leak with aiohttp.request (`#1756 `_) - Disable cleanup closed ssl transports by default. 
- Exception in request handling if the server responds before the body is sent (`#1761 `_) ---- 2.0.4 (2017-03-27) ================== - Memory leak with aiohttp.request (`#1756 `_) - Encoding is always UTF-8 in POST data (`#1750 `_) - Do not add "Content-Disposition" header by default (`#1755 `_) ---- 2.0.3 (2017-03-24) ================== - Call https website through proxy will cause error (`#1745 `_) - Fix exception on multipart/form-data post if content-type is not set (`#1743 `_) ---- 2.0.2 (2017-03-21) ================== - Fixed Application.on_loop_available signal (`#1739 `_) - Remove debug code ---- 2.0.1 (2017-03-21) ================== - Fix allow-head to include name on route (`#1737 `_) - Fixed AttributeError in WebSocketResponse.can_prepare (`#1736 `_) ---- 2.0.0 (2017-03-20) ================== - Added `json` to `ClientSession.request()` method (`#1726 `_) - Added session's `raise_for_status` parameter, automatically calls raise_for_status() on any request. (`#1724 `_) - `response.json()` raises `ClientResponseError` exception if response's content type does not match (`#1723 `_) - Cleanup timer and loop handle on any client exception. - Deprecate `loop` parameter for Application's constructor - Properly handle payload errors (`#1710 `_) - Added `ClientWebSocketResponse.get_extra_info()` (`#1717 `_) - It is not possible to combine Transfer-Encoding and chunked parameter, same for compress and Content-Encoding (`#1655 `_) - Connector's `limit` parameter indicates total concurrent connections. New `limit_per_host` added, indicates total connections per endpoint. 
(`#1601 `_) - Use url's `raw_host` for name resolution (`#1685 `_) - Change `ClientResponse.url` to `yarl.URL` instance (`#1654 `_) - Add max_size parameter to web.Request reading methods (`#1133 `_) - Web Request.post() stores data in temp files (`#1469 `_) - Add the `allow_head=True` keyword argument for `add_get` (`#1618 `_) - `run_app` and the Command Line Interface now support serving over Unix domain sockets for faster inter-process communication. - `run_app` now supports passing a preexisting socket object. This can be useful e.g. for socket-based activated applications, when binding of a socket is done by the parent process. - Implementation for Trailer headers parser is broken (`#1619 `_) - Fix FileResponse to not fall on bad request (range out of file size) - Fix FileResponse to correct stream video to Chromes - Deprecate public low-level api (`#1657 `_) - Deprecate `encoding` parameter for ClientSession.request() method - Dropped aiohttp.wsgi (`#1108 `_) - Dropped `version` from ClientSession.request() method - Dropped websocket version 76 support (`#1160 `_) - Dropped: `aiohttp.protocol.HttpPrefixParser` (`#1590 `_) - Dropped: Servers response's `.started`, `.start()` and `.can_start()` method (`#1591 `_) - Dropped: Adding `sub app` via `app.router.add_subapp()` is deprecated use `app.add_subapp()` instead (`#1592 `_) - Dropped: `Application.finish()` and `Application.register_on_finish()` (`#1602 `_) - Dropped: `web.Request.GET` and `web.Request.POST` - Dropped: aiohttp.get(), aiohttp.options(), aiohttp.head(), aiohttp.post(), aiohttp.put(), aiohttp.patch(), aiohttp.delete(), and aiohttp.ws_connect() (`#1593 `_) - Dropped: `aiohttp.web.WebSocketResponse.receive_msg()` (`#1605 `_) - Dropped: `ServerHttpProtocol.keep_alive_timeout` attribute and `keep-alive`, `keep_alive_on`, `timeout`, `log` constructor parameters (`#1606 `_) - Dropped: `TCPConnector's`` `.resolve`, `.resolved_hosts`, `.clear_resolved_hosts()` attributes and `resolve` constructor 
parameter (`#1607 `_) - Dropped `ProxyConnector` (`#1609 `_) ---- 1.3.5 (2017-03-16) ================== - Fixed None timeout support (`#1720 `_) ---- 1.3.4 (2017-03-14) ================== - Revert timeout handling in client request - Fix StreamResponse representation after eof - Fix file_sender to not fall on bad request (range out of file size) - Fix file_sender to correct stream video to Chromes - Fix NotImplementedError server exception (`#1703 `_) - Clearer error message for URL without a host name. (`#1691 `_) - Silence deprecation warning in __repr__ (`#1690 `_) - IDN + HTTPS = `ssl.CertificateError` (`#1685 `_) ---- 1.3.3 (2017-02-19) ================== - Fixed memory leak in time service (`#1656 `_) ---- 1.3.2 (2017-02-16) ================== - Awaiting on WebSocketResponse.send_* does not work (`#1645 `_) - Fix multiple calls to client ws_connect when using a shared header dict (`#1643 `_) - Make CookieJar.filter_cookies() accept plain string parameter. (`#1636 `_) ---- 1.3.1 (2017-02-09) ================== - Handle CLOSING in WebSocketResponse.__anext__ - Fixed AttributeError 'drain' for server websocket handler (`#1613 `_) ---- 1.3.0 (2017-02-08) ================== - Multipart writer validates the data on append instead of on a request send (`#920 `_) - Multipart reader accepts multipart messages with or without their epilogue to consistently handle valid and legacy behaviors (`#1526 `_) (`#1581 `_) - Separate read + connect + request timeouts # 1523 - Do not swallow Upgrade header (`#1587 `_) - Fix polls demo run application (`#1487 `_) - Ignore unknown 1XX status codes in client (`#1353 `_) - Fix sub-Multipart messages missing their headers on serialization (`#1525 `_) - Do not use readline when reading the content of a part in the multipart reader (`#1535 `_) - Add optional flag for quoting `FormData` fields (`#916 `_) - 416 Range Not Satisfiable if requested range end > file size (`#1588 `_) - Having a `:` or `@` in a route does not work (`#1552 `_) - 
Added `receive_timeout` timeout for websocket to receive complete message. (`#1325 `_) - Added `heartbeat` parameter for websocket to automatically send `ping` message. (`#1024 `_) (`#777 `_) - Remove `web.Application` dependency from `web.UrlDispatcher` (`#1510 `_) - Accepting back-pressure from slow websocket clients (`#1367 `_) - Do not pause transport during set_parser stage (`#1211 `_) - Lingering close does not terminate before timeout (`#1559 `_) - `setsockopt` may raise `OSError` exception if socket is closed already (`#1595 `_) - Lots of CancelledError when requests are interrupted (`#1565 `_) - Allow users to specify what should happen to decoding errors when calling a responses `text()` method (`#1542 `_) - Back port std module `http.cookies` for python3.4.2 (`#1566 `_) - Maintain url's fragment in client response (`#1314 `_) - Allow concurrently close WebSocket connection (`#754 `_) - Gzipped responses with empty body raises ContentEncodingError (`#609 `_) - Return 504 if request handle raises TimeoutError. - Refactor how we use keep-alive and close lingering timeouts. - Close response connection if we can not consume whole http message during client response release - Abort closed ssl client transports, broken servers can keep socket open for an unlimited time (`#1568 `_) - Log warning instead of `RuntimeError` if the websocket connection is closed.
- Deprecated: `aiohttp.protocol.HttpPrefixParser` will be removed in 1.4 (`#1590 `_) - Deprecated: Servers response's `.started`, `.start()` and `.can_start()` method will be removed in 1.4 (`#1591 `_) - Deprecated: Adding `sub app` via `app.router.add_subapp()` is deprecated use `app.add_subapp()` instead, will be removed in 1.4 (`#1592 `_) - Deprecated: aiohttp.get(), aiohttp.options(), aiohttp.head(), aiohttp.post(), aiohttp.put(), aiohttp.patch(), aiohttp.delete(), and aiohttp.ws_connect() will be removed in 1.4 (`#1593 `_) - Deprecated: `Application.finish()` and `Application.register_on_finish()` will be removed in 1.4 (`#1602 `_) ---- 1.2.0 (2016-12-17) ================== - Extract `BaseRequest` from `web.Request`, introduce `web.Server` (former `RequestHandlerFactory`), introduce new low-level web server which is not coupled with `web.Application` and routing (`#1362 `_) - Make `TestServer.make_url` compatible with `yarl.URL` (`#1389 `_) - Implement range requests for static files (`#1382 `_) - Support task attribute for StreamResponse (`#1410 `_) - Drop `TestClient.app` property, use `TestClient.server.app` instead (BACKWARD INCOMPATIBLE) - Drop `TestClient.handler` property, use `TestClient.server.handler` instead (BACKWARD INCOMPATIBLE) - `TestClient.server` property returns a test server instance, was `asyncio.AbstractServer` (BACKWARD INCOMPATIBLE) - Follow gunicorn's signal semantics in `Gunicorn[UVLoop]WebWorker` (`#1201 `_) - Call worker_int and worker_abort callbacks in `Gunicorn[UVLoop]WebWorker` (`#1202 `_) - Has functional tests for client proxy (`#1218 `_) - Fix bugs with client proxy target path and proxy host with port (`#1413 `_) - Fix bugs related to the use of unicode hostnames (`#1444 `_) - Preserve cookie quoting/escaping (`#1453 `_) - FileSender will send gzipped response if gzip version available (`#1426 `_) - Don't override `Content-Length` header in `web.Response` if no body was set (`#1400 `_) - Introduce `router.post_init()` for 
solving (`#1373 `_) - Fix raise error in case of multiple calls of `TimeService.stop()` - Allow to raise web exceptions on router resolving stage (`#1460 `_) - Add a warning for session creation outside of coroutine (`#1468 `_) - Avoid a race when application might start accepting incoming requests but startup signals are not processed yet e98e8c6 - Raise a `RuntimeError` when trying to change the status of the HTTP response after the headers have been sent (`#1480 `_) - Fix bug with https proxy acquired cleanup (`#1340 `_) - Use UTF-8 as the default encoding for multipart text parts (`#1484 `_) ---- 1.1.6 (2016-11-28) ================== - Fix `BodyPartReader.read_chunk` bug where it returns zero bytes before `EOF` (`#1428 `_) ---- 1.1.5 (2016-11-16) ================== - Fix static file serving in fallback mode (`#1401 `_) ---- 1.1.4 (2016-11-14) ================== - Make `TestServer.make_url` compatible with `yarl.URL` (`#1389 `_) - Generate informative exception on redirects from server which does not provide redirection headers (`#1396 `_) ---- 1.1.3 (2016-11-10) ================== - Support *root* resources for sub-applications (`#1379 `_) ---- 1.1.2 (2016-11-08) ================== - Allow starting variables with an underscore (`#1379 `_) - Properly process UNIX sockets by gunicorn worker (`#1375 `_) - Fix ordering for `FrozenList` - Don't propagate pre and post signals to sub-application (`#1377 `_) ---- 1.1.1 (2016-11-04) ================== - Fix documentation generation (`#1120 `_) ---- 1.1.0 (2016-11-03) ================== - Drop deprecated `WSClientDisconnectedError` (BACKWARD INCOMPATIBLE) - Use `yarl.URL` in client API. The change is 99% backward compatible but `ClientResponse.url` is a `yarl.URL` instance now. (`#1217 `_) - Close idle keep-alive connections on shutdown (`#1222 `_) - Modify regex in AccessLogger to accept underscore and numbers (`#1225 `_) - Use `yarl.URL` in web server API. `web.Request.rel_url` and `web.Request.url` are added.
URLs and templates are percent-encoded now. (`#1224 `_) - Accept `yarl.URL` by server redirections (`#1278 `_) - Return `yarl.URL` by `.make_url()` testing utility (`#1279 `_) - Properly format IPv6 addresses by `aiohttp.web.run_app` (`#1139 `_) - Use `yarl.URL` by server API (`#1288 `_) * Introduce `resource.url_for()`, deprecate `resource.url()`. * Implement `StaticResource`. * Inherit `SystemRoute` from `AbstractRoute` * Drop old-style routes: `Route`, `PlainRoute`, `DynamicRoute`, `StaticRoute`, `ResourceAdapter`. - Revert `resp.url` back to `str`, introduce `resp.url_obj` (`#1292 `_) - Raise ValueError if BasicAuth login has a ":" character (`#1307 `_) - Fix bug when ClientRequest send payload file with opened as open('filename', 'r+b') (`#1306 `_) - Enhancement to AccessLogger (pass *extra* dict) (`#1303 `_) - Show more verbose message on import errors (`#1319 `_) - Added save and load functionality for `CookieJar` (`#1219 `_) - Added option on `StaticRoute` to follow symlinks (`#1299 `_) - Force encoding of `application/json` content type to utf-8 (`#1339 `_) - Fix invalid invocations of `errors.LineTooLong` (`#1335 `_) - Websockets: Stop `async for` iteration when connection is closed (`#1144 `_) - Ensure TestClient HTTP methods return a context manager (`#1318 `_) - Raise `ClientDisconnectedError` to `FlowControlStreamReader` read function if `ClientSession` object is closed by client when reading data. (`#1323 `_) - Document deployment without `Gunicorn` (`#1120 `_) - Add deprecation warning for MD5 and SHA1 digests when used for fingerprint of site certs in TCPConnector. (`#1186 `_) - Implement sub-applications (`#1301 `_) - Don't inherit `web.Request` from `dict` but implement `MutableMapping` protocol. - Implement frozen signals - Don't inherit `web.Application` from `dict` but implement `MutableMapping` protocol. 
- Support freezing for web applications - Accept access_log parameter in `web.run_app`, use `None` to disable logging - Don't flap `tcp_cork` and `tcp_nodelay` in regular request handling. `tcp_nodelay` is still enabled by default. - Improve performance of web server by removing premature computing of Content-Type if the value was set by `web.Response` constructor. While the patch boosts speed of trivial `web.Response(text='OK', content_type='text/plain)` very well please don't expect significant boost of your application -- a couple DB requests and business logic is still the main bottleneck. - Boost performance by adding a custom time service (`#1350 `_) - Extend `ClientResponse` with `content_type` and `charset` properties like in `web.Request`. (`#1349 `_) - Disable aiodns by default (`#559 `_) - Don't flap `tcp_cork` in client code, use TCP_NODELAY mode by default. - Implement `web.Request.clone()` (`#1361 `_) ---- 1.0.5 (2016-10-11) ================== - Fix StreamReader._read_nowait to return all available data up to the requested amount (`#1297 `_) ---- 1.0.4 (2016-09-22) ================== - Fix FlowControlStreamReader.read_nowait so that it checks whether the transport is paused (`#1206 `_) ---- 1.0.2 (2016-09-22) ================== - Make CookieJar compatible with 32-bit systems (`#1188 `_) - Add missing `WSMsgType` to `web_ws.__all__`, see (`#1200 `_) - Fix `CookieJar` ctor when called with `loop=None` (`#1203 `_) - Fix broken upper-casing in wsgi support (`#1197 `_) ---- 1.0.1 (2016-09-16) ================== - Restore `aiohttp.web.MsgType` alias for `aiohttp.WSMsgType` for sake of backward compatibility (`#1178 `_) - Tune alabaster schema. - Use `text/html` content type for displaying index pages by static file handler. 
- Fix `AssertionError` in static file handling (`#1177 `_) - Fix access log formats `%O` and `%b` for static file handling - Remove `debug` setting of GunicornWorker, use `app.debug` to control its debug-mode instead ---- 1.0.0 (2016-09-16) ================== - Change default size for client session's connection pool from unlimited to 20 (`#977 `_) - Add IE support for cookie deletion. (`#994 `_) - Remove deprecated `WebSocketResponse.wait_closed` method (BACKWARD INCOMPATIBLE) - Remove deprecated `force` parameter for `ClientResponse.close` method (BACKWARD INCOMPATIBLE) - Avoid using of mutable CIMultiDict kw param in make_mocked_request (`#997 `_) - Make WebSocketResponse.close a little bit faster by avoiding new task creating just for timeout measurement - Add `proxy` and `proxy_auth` params to `client.get()` and family, deprecate `ProxyConnector` (`#998 `_) - Add support for websocket send_json and receive_json, synchronize server and client API for websockets (`#984 `_) - Implement router shortcuts for most useful HTTP methods, use `app.router.add_get()`, `app.router.add_post()` etc.
instead of `app.router.add_route()` (`#986 `_) - Support SSL connections for gunicorn worker (`#1003 `_) - Move obsolete examples to legacy folder - Switch to multidict 2.0 and title-cased strings (`#1015 `_) - `{FOO}e` logger format is case-sensitive now - Fix logger report for unix socket 8e8469b - Rename aiohttp.websocket to aiohttp._ws_impl - Rename ``aiohttp.MsgType`` to ``aiohttp.WSMsgType`` - Introduce ``aiohttp.WSMessage`` officially - Rename Message -> WSMessage - Remove deprecated decode param from resp.read(decode=True) - Use 5min default client timeout (`#1028 `_) - Relax HTTP method validation in UrlDispatcher (`#1037 `_) - Pin minimal supported asyncio version to 3.4.2+ (`loop.is_close()` should be present) - Remove aiohttp.websocket module (BACKWARD INCOMPATIBLE) Please use high-level client and server approaches - Link header for 451 status code is mandatory - Fix test_client fixture to allow multiple clients per test (`#1072 `_) - make_mocked_request now accepts dict as headers (`#1073 `_) - Add Python 3.5.2/3.6+ compatibility patch for async generator protocol change (`#1082 `_) - Improvement test_client can accept instance object (`#1083 `_) - Simplify ServerHttpProtocol implementation (`#1060 `_) - Add a flag for optional showing directory index for static file handling (`#921 `_) - Define `web.Application.on_startup()` signal handler (`#1103 `_) - Drop ChunkedParser and LinesParser (`#1111 `_) - Call `Application.startup` in GunicornWebWorker (`#1105 `_) - Fix client handling hostnames with 63 bytes when a port is given in the url (`#1044 `_) - Implement proxy support for ClientSession.ws_connect (`#1025 `_) - Return named tuple from WebSocketResponse.can_prepare (`#1016 `_) - Fix access_log_format in `GunicornWebWorker` (`#1117 `_) - Setup Content-Type to application/octet-stream by default (`#1124 `_) - Deprecate debug parameter from app.make_handler(), use `Application(debug=True)` instead (`#1121 `_) - Remove fragment string in request path 
(`#846 `_) - Use aiodns.DNSResolver.gethostbyname() if available (`#1136 `_) - Fix static file sending on uvloop when sendfile is available (`#1093 `_) - Make prettier urls if query is empty dict (`#1143 `_) - Fix redirects for HEAD requests (`#1147 `_) - Default value for `StreamReader.read_nowait` is -1 from now (`#1150 `_) - `aiohttp.StreamReader` is not inherited from `asyncio.StreamReader` from now (BACKWARD INCOMPATIBLE) (`#1150 `_) - Streams documentation added (`#1150 `_) - Add `multipart` coroutine method for web Request object (`#1067 `_) - Publish ClientSession.loop property (`#1149 `_) - Fix static file with spaces (`#1140 `_) - Fix piling up asyncio loop by cookie expiration callbacks (`#1061 `_) - Drop `Timeout` class for sake of `async_timeout` external library. `aiohttp.Timeout` is an alias for `async_timeout.timeout` - `use_dns_cache` parameter of `aiohttp.TCPConnector` is `True` by default (BACKWARD INCOMPATIBLE) (`#1152 `_) - `aiohttp.TCPConnector` uses asynchronous DNS resolver if available by default (BACKWARD INCOMPATIBLE) (`#1152 `_) - Conform to RFC3986 - do not include url fragments in client requests (`#1174 `_) - Drop `ClientSession.cookies` (BACKWARD INCOMPATIBLE) (`#1173 `_) - Refactor `AbstractCookieJar` public API (BACKWARD INCOMPATIBLE) (`#1173 `_) - Fix clashing cookies which have the same name but belong to different domains (BACKWARD INCOMPATIBLE) (`#1125 `_) - Support binary Content-Transfer-Encoding (`#1169 `_) ---- 0.22.5 (08-02-2016) =================== - Pin multidict version to >=1.2.2 ---- 0.22.3 (07-26-2016) =================== - Do not filter cookies if unsafe flag provided (`#1005 `_) ---- 0.22.2 (07-23-2016) =================== - Suppress CancelledError when Timeout raises TimeoutError (`#970 `_) - Don't expose `aiohttp.__version__` - Add unsafe parameter to CookieJar (`#968 `_) - Use unsafe cookie jar in test client tools - Expose aiohttp.CookieJar name ---- 0.22.1 (07-16-2016) =================== - Large cookie
expiration/max-age does not break an event loop from now (fixes (`#967 `_)) ---- 0.22.0 (07-15-2016) =================== - Fix bug in serving static directory (`#803 `_) - Fix command line arg parsing (`#797 `_) - Fix a documentation chapter about cookie usage (`#790 `_) - Handle empty body with gzipped encoding (`#758 `_) - Support 451 Unavailable For Legal Reasons http status (`#697 `_) - Fix Cookie share example and few small typos in docs (`#817 `_) - UrlDispatcher.add_route with partial coroutine handler (`#814 `_) - Optional support for aiodns (`#728 `_) - Add ServiceRestart and TryAgainLater websocket close codes (`#828 `_) - Fix prompt message for `web.run_app` (`#832 `_) - Allow to pass None as a timeout value to disable timeout logic (`#834 `_) - Fix leak of connection slot during connection error (`#835 `_) - Gunicorn worker with uvloop support `aiohttp.worker.GunicornUVLoopWebWorker` (`#878 `_) - Don't send body in response to HEAD request (`#838 `_) - Skip the preamble in MultipartReader (`#881 `_) - Implement BasicAuth decode classmethod. 
(`#744 `_) - Don't crash logger when transport is None (`#889 `_) - Use a create_future compatibility wrapper instead of creating Futures directly (`#896 `_) - Add test utilities to aiohttp (`#902 `_) - Improve Request.__repr__ (`#875 `_) - Skip DNS resolving if provided host is already an ip address (`#874 `_) - Add headers to ClientSession.ws_connect (`#785 `_) - Document that server can send pre-compressed data (`#906 `_) - Don't add Content-Encoding and Transfer-Encoding if no body (`#891 `_) - Add json() convenience methods to websocket message objects (`#897 `_) - Add client_resp.raise_for_status() (`#908 `_) - Implement cookie filter (`#799 `_) - Include an example of middleware to handle error pages (`#909 `_) - Fix error handling in StaticFileMixin (`#856 `_) - Add mocked request helper (`#900 `_) - Fix empty ALLOW Response header for cls based View (`#929 `_) - Respect CONNECT method to implement a proxy server (`#847 `_) - Add pytest_plugin (`#914 `_) - Add tutorial - Add backlog option to support more than 128 (default value in "create_server" function) concurrent connections (`#892 `_) - Allow configuration of header size limits (`#912 `_) - Separate sending file logic from StaticRoute dispatcher (`#901 `_) - Drop deprecated share_cookies connector option (BACKWARD INCOMPATIBLE) - Drop deprecated support for tuple as auth parameter. Use aiohttp.BasicAuth instead (BACKWARD INCOMPATIBLE) - Remove deprecated `request.payload` property, use `content` instead. 
(BACKWARD INCOMPATIBLE) - Drop all mentions about api changes in documentation for versions older than 0.16 - Allow to override default cookie jar (`#963 `_) - Add manylinux wheel builds - Dup a socket for sendfile usage (`#964 `_) ---- 0.21.6 (05-05-2016) =================== - Drop initial query parameters on redirects (`#853 `_) ---- 0.21.5 (03-22-2016) =================== - Fix command line arg parsing (`#797 `_) ---- 0.21.4 (03-12-2016) =================== - Fix ResourceAdapter: don't add method to allowed if resource is not match (`#826 `_) - Fix Resource: append found method to returned allowed methods ---- 0.21.2 (02-16-2016) =================== - Fix a regression: support for handling ~/path in static file routes was broken (`#782 `_) ---- 0.21.1 (02-10-2016) =================== - Make new resources classes public (`#767 `_) - Add `router.resources()` view - Fix cmd-line parameter names in doc ---- 0.21.0 (02-04-2016) =================== - Introduce on_shutdown signal (`#722 `_) - Implement raw input headers (`#726 `_) - Implement web.run_app utility function (`#734 `_) - Introduce on_cleanup signal - Deprecate Application.finish() / Application.register_on_finish() in favor of on_cleanup. - Get rid of bare aiohttp.request(), aiohttp.get() and family in docs (`#729 `_) - Deprecate bare aiohttp.request(), aiohttp.get() and family (`#729 `_) - Refactor keep-alive support (`#737 `_) - Enable keepalive for HTTP 1.0 by default - Disable it for HTTP 0.9 (who cares about 0.9, BTW?) 
- For keepalived connections - Send `Connection: keep-alive` for HTTP 1.0 only - don't send `Connection` header for HTTP 1.1 - For non-keepalived connections - Send `Connection: close` for HTTP 1.1 only - don't send `Connection` header for HTTP 1.0 - Add version parameter to ClientSession constructor, deprecate it for session.request() and family (`#736 `_) - Enable access log by default (`#735 `_) - Deprecate app.router.register_route() (the method was not documented intentionally BTW). - Deprecate app.router.named_routes() in favor of app.router.named_resources() - route.add_static accepts pathlib.Path now (`#743 `_) - Add command line support: `$ python -m aiohttp.web package.main` (`#740 `_) - FAQ section was added to docs. Enjoy and feel free to contribute new topics - Add async context manager support to ClientSession - Document ClientResponse's host, method, url properties - Use CORK/NODELAY in client API (`#748 `_) - ClientSession.close and Connector.close are coroutines now - Close client connection on exception in ClientResponse.release() - Allow to read multipart parts without content-length specified (`#750 `_) - Add support for unix domain sockets to gunicorn worker (`#470 `_) - Add test for default Expect handler (`#601 `_) - Add the first demo project - Rename `loader` keyword argument in `web.Request.json` method. (`#646 `_) - Add local socket binding for TCPConnector (`#678 `_) ---- 0.20.2 (01-07-2016) =================== - Enable use of `await` for a class based view (`#717 `_) - Check address family to fill wsgi env properly (`#718 `_) - Fix memory leak in headers processing (thanks to Marco Paolini) (`#723 `_) ---- 0.20.1 (12-30-2015) =================== - Raise RuntimeError if Timeout context manager was used outside of task context.
- Add number of bytes to stream.read_nowait (`#700 `_) - Use X-FORWARDED-PROTO for wsgi.url_scheme when available ---- 0.20.0 (12-28-2015) =================== - Extend list of web exceptions, add HTTPMisdirectedRequest, HTTPUpgradeRequired, HTTPPreconditionRequired, HTTPTooManyRequests, HTTPRequestHeaderFieldsTooLarge, HTTPVariantAlsoNegotiates, HTTPNotExtended, HTTPNetworkAuthenticationRequired status codes (`#644 `_) - Do not remove AUTHORIZATION header by WSGI handler (`#649 `_) - Fix broken support for https proxies with authentication (`#617 `_) - Get REMOTE_* and SERVER_* http vars from headers when listening on unix socket (`#654 `_) - Add HTTP 308 support (`#663 `_) - Add Tf format (time to serve request in seconds, %06f format) to access log (`#669 `_) - Remove one and a half years long deprecated ClientResponse.read_and_close() method - Optimize chunked encoding: use a single syscall instead of 3 calls on sending chunked encoded data - Use TCP_CORK and TCP_NODELAY to optimize network latency and throughput (`#680 `_) - Websocket XOR performance improved (`#687 `_) - Avoid sending cookie attributes in Cookie header (`#613 `_) - Round server timeouts to seconds for grouping pending calls. That leads to fewer poller syscalls e.g. epoll.poll().
(`#702 `_) - Close connection on websocket handshake error (`#703 `_) - Implement class based views (`#684 `_) - Add *headers* parameter to ws_connect() (`#709 `_) - Drop unused function `parse_remote_addr()` (`#708 `_) - Close session on exception (`#707 `_) - Store http code and headers in WSServerHandshakeError (`#706 `_) - Make some low-level message properties readonly (`#710 `_) ---- 0.19.0 (11-25-2015) =================== - Memory leak in ParserBuffer (`#579 `_) - Support gunicorn's `max_requests` settings in gunicorn worker - Fix wsgi environment building (`#573 `_) - Improve access logging (`#572 `_) - Drop unused host and port from low-level server (`#586 `_) - Add Python 3.5 `async for` implementation to server websocket (`#543 `_) - Add Python 3.5 `async for` implementation to client websocket - Add Python 3.5 `async with` implementation to client websocket - Add charset parameter to web.Response constructor (`#593 `_) - Forbid passing both Content-Type header and content_type or charset params into web.Response constructor - Forbid duplicating of web.Application and web.Request (`#602 `_) - Add an option to pass Origin header in ws_connect (`#607 `_) - Add json_response function (`#592 `_) - Make concurrent connections respect limits (`#581 `_) - Collect history of responses if redirects occur (`#614 `_) - Enable passing pre-compressed data in requests (`#621 `_) - Expose named routes via UrlDispatcher.named_routes() (`#622 `_) - Allow disabling sendfile by environment variable AIOHTTP_NOSENDFILE (`#629 `_) - Use ensure_future if available - Always quote params for Content-Disposition (`#641 `_) - Support async for in multipart reader (`#640 `_) - Add Timeout context manager (`#611 `_) ---- 0.18.4 (13-11-2015) =================== - Relax rule for router names again by adding dash to allowed characters: they may contain identifiers, dashes, dots and colons ---- 0.18.3 (25-10-2015) =================== - Fix formatting for _RequestContextManager helper
(`#590 `_) ---- 0.18.2 (22-10-2015) =================== - Fix regression for OpenSSL < 1.0.0 (`#583 `_) ---- 0.18.1 (20-10-2015) =================== - Relax rule for router names: they may contain dots and colons starting from now ---- 0.18.0 (19-10-2015) =================== - Use errors.HttpProcessingError.message as HTTP error reason and message (`#459 `_) - Optimize cythonized multidict a bit - Change repr's of multidicts and multidict views - default headers in ClientSession are now case-insensitive - Make '=' char and 'wss://' schema safe in urls (`#477 `_) - `ClientResponse.close()` forces connection closing by default from now (`#479 `_) N.B. Backward incompatible change: was `.close(force=False)`. Using `force` parameter for the method is deprecated: use `.release()` instead. - Properly requote URL's path (`#480 `_) - add `skip_auto_headers` parameter for client API (`#486 `_) - Properly parse URL path in aiohttp.web.Request (`#489 `_) - Raise RuntimeError when chunked enabled and HTTP is 1.0 (`#488 `_) - Fix a bug with processing io.BytesIO as data parameter for client API (`#500 `_) - Skip auto-generation of Content-Type header (`#507 `_) - Use sendfile facility for static file handling (`#503 `_) - Default `response_factory` in `app.router.add_static` now is `StreamResponse`, not `None`. The functionality is not changed if default is not specified. - Drop `ClientResponse.message` attribute, it was always implementation detail. - Streams are optimized for speed and mostly memory in case of a big HTTP message sizes (`#496 `_) - Fix a bug for server-side cookies for dropping cookie and setting it again without Max-Age parameter.
- Don't trim redirect URL in client API (`#499 `_) - Extend precision of access log "D" to milliseconds (`#527 `_) - Deprecate `StreamResponse.start()` method in favor of `StreamResponse.prepare()` coroutine (`#525 `_) `.start()` is still supported but responses begun with `.start()` does not call signal for response preparing to be sent. - Add `StreamReader.__repr__` - Drop Python 3.3 support, from now minimal required version is Python 3.4.1 (`#541 `_) - Add `async with` support for `ClientSession.request()` and family (`#536 `_) - Ignore message body on 204 and 304 responses (`#505 `_) - `TCPConnector` processed both IPv4 and IPv6 by default (`#559 `_) - Add `.routes()` view for urldispatcher (`#519 `_) - Route name should be a valid identifier name from now (`#567 `_) - Implement server signals (`#562 `_) - Drop a year-old deprecated *files* parameter from client API. - Added `async for` support for aiohttp stream (`#542 `_) ---- 0.17.4 (09-29-2015) =================== - Properly parse URL path in aiohttp.web.Request (`#489 `_) - Add missing coroutine decorator, the client api is await-compatible now ---- 0.17.3 (08-28-2015) =================== - Remove Content-Length header on compressed responses (`#450 `_) - Support Python 3.5 - Improve performance of transport in-use list (`#472 `_) - Fix connection pooling (`#473 `_) ---- 0.17.2 (08-11-2015) =================== - Don't forget to pass `data` argument forward (`#462 `_) - Fix multipart read bytes count (`#463 `_) ---- 0.17.1 (08-10-2015) =================== - Fix multidict comparison to arbitrary abc.Mapping ---- 0.17.0 (08-04-2015) =================== - Make StaticRoute support Last-Modified and If-Modified-Since headers (`#386 `_) - Add Request.if_modified_since and Stream.Response.last_modified properties - Fix deflate compression when writing a chunked response (`#395 `_) - Request`s content-length header is cleared now after redirect from POST method (`#391 `_) - Return a 400 if server received a non 
HTTP content (`#405 `_) - Fix keep-alive support for aiohttp clients (`#406 `_) - Allow gzip compression in high-level server response interface (`#403 `_) - Rename TCPConnector.resolve and family to dns_cache (`#415 `_) - Make UrlDispatcher ignore quoted characters during url matching (`#414 `_) Backward-compatibility warning: this may change the url matched by your queries if they send quoted character (like %2F for /) (`#414 `_) - Use optional cchardet accelerator if present (`#418 `_) - Borrow loop from Connector in ClientSession if loop is not set - Add context manager support to ClientSession for session closing. - Add toplevel get(), post(), put(), head(), delete(), options(), patch() coroutines. - Fix IPv6 support for client API (`#425 `_) - Pass SSL context through proxy connector (`#421 `_) - Make the rule: path for add_route should start with slash - Don't process request finishing by low-level server on closed event loop - Don't override data if multiple files are uploaded with same key (`#433 `_) - Ensure multipart.BodyPartReader.read_chunk read all the necessary data to avoid false assertions about malformed multipart payload - Don't send body for 204, 205 and 304 http exceptions (`#442 `_) - Correctly skip Cython compilation in MSVC not found (`#453 `_) - Add response factory to StaticRoute (`#456 `_) - Don't append trailing CRLF for multipart.BodyPartReader (`#454 `_) ---- 0.16.6 (07-15-2015) =================== - Skip compilation on Windows if vcvarsall.bat cannot be found (`#438 `_) ---- 0.16.5 (06-13-2015) =================== - Get rid of all comprehensions and yielding in _multidict (`#410 `_) ---- 0.16.4 (06-13-2015) =================== - Don't clear current exception in multidict's `__repr__` (cythonized versions) (`#410 `_) ---- 0.16.3 (05-30-2015) =================== - Fix StaticRoute vulnerability to directory traversal attacks (`#380 `_) ---- 0.16.2 (05-27-2015) =================== - Update python version required for `__del__` usage: it's 
actually 3.4.1 instead of 3.4.0 - Add check for presence of loop.is_closed() method before call the former (`#378 `_) ---- 0.16.1 (05-27-2015) =================== - Fix regression in static file handling (`#377 `_) ---- 0.16.0 (05-26-2015) =================== - Unset waiter future after cancellation (`#363 `_) - Update request url with query parameters (`#372 `_) - Support new `fingerprint` param of TCPConnector to enable verifying SSL certificates via MD5, SHA1, or SHA256 digest (`#366 `_) - Setup uploaded filename if field value is binary and transfer encoding is not specified (`#349 `_) - Implement `ClientSession.close()` method - Implement `connector.closed` readonly property - Implement `ClientSession.closed` readonly property - Implement `ClientSession.connector` readonly property - Implement `ClientSession.detach` method - Add `__del__` to client-side objects: sessions, connectors, connections, requests, responses. - Refactor connections cleanup by connector (`#357 `_) - Add `limit` parameter to connector constructor (`#358 `_) - Add `request.has_body` property (`#364 `_) - Add `response_class` parameter to `ws_connect()` (`#367 `_) - `ProxyConnector` does not support keep-alive requests by default starting from now (`#368 `_) - Add `connector.force_close` property - Add ws_connect to ClientSession (`#374 `_) - Support optional `chunk_size` parameter in `router.add_static()` ---- 0.15.3 (04-22-2015) =================== - Fix graceful shutdown handling - Fix `Expect` header handling for not found and not allowed routes (`#340 `_) ---- 0.15.2 (04-19-2015) =================== - Flow control subsystem refactoring - HTTP server performance optimizations - Allow to match any request method with `*` - Explicitly call drain on transport (`#316 `_) - Make chardet module dependency mandatory (`#318 `_) - Support keep-alive for HTTP 1.0 (`#325 `_) - Do not chunk single file during upload (`#327 `_) - Add ClientSession object for cookie storage and default headers 
(`#328 `_) - Add `keep_alive_on` argument for HTTP server handler. ---- 0.15.1 (03-31-2015) =================== - Pass Autobahn Testsuite tests - Fixed websocket fragmentation - Fixed websocket close procedure - Fixed parser buffer limits - Added `timeout` parameter to WebSocketResponse ctor - Added `WebSocketResponse.close_code` attribute ---- 0.15.0 (03-27-2015) =================== - Client WebSockets support - New Multipart system (`#273 `_) - Support for "Expect" header (`#287 `_) (`#267 `_) - Set default Content-Type for post requests (`#184 `_) - Fix issue with construction dynamic route with regexps and trailing slash (`#266 `_) - Add repr to web.Request - Add repr to web.Response - Add repr for NotFound and NotAllowed match infos - Add repr for web.Application - Add repr to UrlMappingMatchInfo (`#217 `_) - Gunicorn 19.2.x compatibility ---- 0.14.4 (01-29-2015) =================== - Fix issue with error during constructing of url with regex parts (`#264 `_) ---- 0.14.3 (01-28-2015) =================== - Use path='/' by default for cookies (`#261 `_) ---- 0.14.2 (01-23-2015) =================== - Connections leak in BaseConnector (`#253 `_) - Do not swallow websocket reader exceptions (`#255 `_) - web.Request's read, text, json are memoized (`#250 `_) ---- 0.14.1 (01-15-2015) =================== - HttpMessage._add_default_headers does not overwrite existing headers (`#216 `_) - Expose multidict classes at package level - add `aiohttp.web.WebSocketResponse` - According to RFC 6455 websocket subprotocol preference order is provided by client, not by server - websocket's ping and pong accept optional message parameter - multidict views do not accept `getall` parameter anymore, it returns the full body anyway. - multidicts have optional Cython optimization, cythonized version of multidicts is about 5 times faster than pure Python. - multidict.getall() returns `list`, not `tuple`.
- Backward incompatible change: now there are two mutable multidicts (`MultiDict`, `CIMultiDict`) and two immutable multidict proxies (`MultiDictProxy` and `CIMultiDictProxy`). Previous edition of multidicts was not a part of public API BTW. - Router refactoring to push Not Allowed and Not Found in middleware processing - Convert `ConnectionError` to `aiohttp.DisconnectedError` and don't eat `ConnectionError` exceptions from web handlers. - Remove hop headers from Response class, wsgi response still uses hop headers. - Allow to send raw chunked encoded response. - Allow to encode output bytes stream into chunked encoding. - Allow to compress output bytes stream with `deflate` encoding. - Server has 75 seconds keepalive timeout now, was non-keepalive by default. - Application does not accept `**kwargs` anymore ((`#243 `_)). - Request is inherited from dict now for making per-request storage to middlewares ((`#242 `_)). ---- 0.13.1 (12-31-2014) =================== - Add `aiohttp.web.StreamResponse.started` property (`#213 `_) - HTML escape traceback text in `ServerHttpProtocol.handle_error` - Mention handler and middlewares in `aiohttp.web.RequestHandler.handle_request` on error ((`#218 `_)) ---- 0.13.0 (12-29-2014) =================== - `StreamResponse.charset` converts value to lower-case on assigning. - Chain exceptions when raising `ClientRequestError`. - Support custom regexps in route variables (`#204 `_) - Fixed graceful shutdown, disable keep-alive on connection closing. - Decode HTTP message with `utf-8` encoding, some servers send headers in utf-8 encoding (`#207 `_) - Support `aiohttp.web` middlewares (`#209 `_) - Add ssl_context to TCPConnector (`#206 `_) ---- 0.12.0 (12-12-2014) =================== - Deep refactoring of `aiohttp.web` in backward-incompatible manner. Sorry, we have to do this.
- Automatically force aiohttp.web handlers to coroutines in `UrlDispatcher.add_route()` (`#186 `_) - Rename `Request.POST()` function to `Request.post()` - Added POST attribute - Response processing refactoring: constructor does not accept Request instance anymore. - Pass application instance to finish callback - Exceptions refactoring - Do not unquote query string in `aiohttp.web.Request` - Fix concurrent access to payload in `RequestHandler.handle_request()` - Add access logging to `aiohttp.web` - Gunicorn worker for `aiohttp.web` - Removed deprecated `AsyncGunicornWorker` - Removed deprecated HttpClient ---- 0.11.0 (11-29-2014) =================== - Support named routes in `aiohttp.web.UrlDispatcher` (`#179 `_) - Make websocket subprotocols conform to spec (`#181 `_) ---- 0.10.2 (11-19-2014) =================== - Don't unquote `environ['PATH_INFO']` in wsgi.py (`#177 `_) ---- 0.10.1 (11-17-2014) =================== - aiohttp.web.HTTPException and descendants now fill response body with string like `404: NotFound` - Fix multidict `__iter__`, the method should iterate over keys, not (key, value) pairs. ---- 0.10.0 (11-13-2014) =================== - Add aiohttp.web subpackage for highlevel HTTP server support. - Add *reason* optional parameter to aiohttp.protocol.Response ctor. - Fix aiohttp.client bug for sending file without content-type. - Change error text for connection closed between server responses from 'Can not read status line' to explicit 'Connection closed by server' - Drop closed connections from connector (`#173 `_) - Set server.transport to None on .closing() (`#172 `_) ---- 0.9.3 (10-30-2014) ================== - Fix compatibility with asyncio 3.4.1+ (`#170 `_) ---- 0.9.2 (10-16-2014) ================== - Improve redirect handling (`#157 `_) - Send raw files as is (`#153 `_) - Better websocket support (`#150 `_) ---- 0.9.1 (08-30-2014) ================== - Added MultiDict support for client request params and data (`#114 `_).
- Fixed parameter type for IncompleteRead exception (`#118 `_). - Strictly require ASCII headers names and values (`#137 `_) - Keep port in ProxyConnector (`#128 `_). - Python 3.4.1 compatibility (`#131 `_). ---- 0.9.0 (07-08-2014) ================== - Better client basic authentication support (`#112 `_). - Fixed incorrect line splitting in HttpRequestParser (`#97 `_). - Support StreamReader and DataQueue as request data. - Client files handling refactoring (`#20 `_). - Backward incompatible: Replace DataQueue with StreamReader for request payload (`#87 `_). ---- 0.8.4 (07-04-2014) ================== - Change ProxyConnector authorization parameters. ---- 0.8.3 (07-03-2014) ================== - Publish TCPConnector properties: verify_ssl, family, resolve, resolved_hosts. - Don't parse message body for HEAD responses. - Refactor client response decoding. ---- 0.8.2 (06-22-2014) ================== - Make ProxyConnector.proxy immutable property. - Make UnixConnector.path immutable property. - Fix resource leak for aiohttp.request() with implicit connector. - Rename Connector's reuse_timeout to keepalive_timeout. ---- 0.8.1 (06-18-2014) ================== - Use case insensitive multidict for server request/response headers. - MultiDict.getall() accepts default value. - Catch server ConnectionError. - Accept MultiDict (and derived) instances in aiohttp.request header argument. - Proxy 'CONNECT' support. ---- 0.8.0 (06-06-2014) ================== - Add support for utf-8 values in HTTP headers - Allow to use custom response class instead of HttpResponse - Use MultiDict for client request headers - Use MultiDict for server request/response headers - Store response headers in ClientResponse.headers attribute - Get rid of timeout parameter in aiohttp.client API - Exceptions refactoring ---- 0.7.3 (05-20-2014) ================== - Simple HTTP proxy support. 
---- 0.7.2 (05-14-2014) ================== - Get rid of `__del__` methods - Use ResourceWarning instead of logging warning record. ---- 0.7.1 (04-28-2014) ================== - Do not unquote client request urls. - Allow multiple waiters on transport drain. - Do not return client connection to pool in case of exceptions. - Rename SocketConnector to TCPConnector and UnixSocketConnector to UnixConnector. ---- 0.7.0 (04-16-2014) ================== - Connection flow control. - HTTP client session/connection pool refactoring. - Better handling for bad server requests. ---- 0.6.5 (03-29-2014) ================== - Added client session reuse timeout. - Better client request cancellation support. - Better handling responses without content length. - Added HttpClient verify_ssl parameter support. ---- 0.6.4 (02-27-2014) ================== - Log content-length missing warning only for put and post requests. ---- 0.6.3 (02-27-2014) ================== - Better support for server exit. - Read response body until EOF if content-length is not defined (`#14 `_) ---- 0.6.2 (02-18-2014) ================== - Fix trailing char in allowed_methods. - Start slow request timer for first request. ---- 0.6.1 (02-17-2014) ================== - Added utility method HttpResponse.read_and_close() - Added slow request timeout. - Enable socket SO_KEEPALIVE if available. ---- 0.6.0 (02-12-2014) ================== - Better handling for process exit. ---- 0.5.0 (01-29-2014) ================== - Allow to use custom HttpRequest client class. - Use gunicorn keepalive setting for asynchronous worker. - Log leaking responses. - python 3.4 compatibility ---- 0.4.4 (11-15-2013) ================== - Resolve only AF_INET family, because it is not clear how to pass extra info to asyncio. ---- 0.4.3 (11-15-2013) ================== - Allow to wait completion of request with `HttpResponse.wait_for_close()` ---- 0.4.2 (11-14-2013) ================== - Handle exception in client request stream. 
- Prevent host resolving for each client request. ---- 0.4.1 (11-12-2013) ================== - Added client support for `expect: 100-continue` header. ---- 0.4 (11-06-2013) ================ - Added custom wsgi application close procedure - Fixed concurrent host failure in HttpClient ---- 0.3 (11-04-2013) ================ - Added PortMapperWorker - Added HttpClient - Added TCP connection timeout to HTTP client - Better client connection errors handling - Gracefully handle process exit ---- 0.2 === - Fix packaging ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. 
## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at andrew.svetlov@gmail.com. 
The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] [homepage]: http://contributor-covenant.org [version]: http://contributor-covenant.org/version/1/4/ ================================================ FILE: CONTRIBUTING.rst ================================================ Contributing ============ Instructions for contributors ----------------------------- In order to make a clone of the GitHub_ repo: open the link and press the "Fork" button on the upper-right menu of the web page. I hope everybody knows how to work with git and github nowadays :) Workflow is pretty straightforward: 1. Clone the GitHub_ repo using the ``--recurse-submodules`` argument 2. Set up your machine with the required development environment 3. Make a change 4. Make sure all tests pass 5. Add a file into the ``CHANGES`` folder, named after the ticket or PR number 6. Commit changes to your own aiohttp clone 7. Make a pull request from the github page of your clone against the master branch 8. Optionally make backport Pull Request(s) for landing a bug fix into released aiohttp versions. .. important:: Please open the "`contributing `_" documentation page to get detailed information about all steps. ..
_GitHub: https://github.com/aio-libs/aiohttp ================================================ FILE: CONTRIBUTORS.txt ================================================ - Contributors - ---------------- A. Jesse Jiryu Davis Abdur Rehman Ali Adam Bannister Adam Cooper Adam Horacek Adam Mills Adrian Krupa Adrián Chaves Ahmed Tahri Alan Bogarin Alan Tse Alec Hanefeld Alejandro Gómez Aleksandr Danshyn Aleksey Kutepov Alex Hayes Alex Key Alex Khomchenko Alex Kuzmenko Alex Lisovoy Alexander Bayandin Alexander Karpinsky Alexander Koshevoy Alexander Malev Alexander Mohr Alexander Shorin Alexander Travov Alexandru Mihai Alexey Firsov Alexey Nikitin Alexey Popravka Alexey Stavrov Alexey Stepanov Almaz Salakhov Amin Etesamian Amit Tulshyan Amy Boyle Anas El Amraoui Anders Melchiorsen Andrei Ursulenko Andrej Antonov Andrew Leech Andrew Lytvyn Andrew Svetlov Andrew Top Andrew Zhou Andrii Soldatenko Anes Abismail Antoine Pietri Anton Kasyanov Anton Zhdan-Pushkin Arcadiy Ivanov Arie Bovenberg Arseny Timoniq Artem Yushkovskiy Arthur Darcet Austin Scola Bai Haoran Ben Bader Ben Beasley Ben Greiner Ben Kallus Ben Timby Benedikt Reinartz Bob Haddleton Boris Feld Borys Vorona Boyi Chen Brett Cannon Brett Higgins Brian Bouterse Brian C. Lane Brian Muller Bruce Merry Bruno Souza Cabral Bryan Kok Bryce Drennan Carl George Cecile Tonglet Chien-Wei Huang Chih-Yuan Chen Chris AtLee Chris Laws Chris Moore Chris Shucksmith Christophe Bornet Christopher Schmitt Claudiu Popa Colin Dunklau Cong Xu Damien Nadé Dan King Dan Xu Daniel García Daniel Golding Daniel Grossmann-Kavanagh Daniel Nelson Daniele Trifirò Danny Song David Bibb David Dzhalaev David Michael Brown Denilson Amorim Denis Matiychuk Denis Moshensky Dennis Kliban Devanshu Koyalkar Dima Veselov Dimitar Dimitrov Diogo Dutra da Mata Dmitriy Safonov Dmitry Doroshev Dmitry Erlikh Dmitry Lukashin Dmitry Marakasov Dmitry Shamov Dmitry Trofimov Dmytro Bohomiakov Dmytro Kuznetsov Dustin J. 
Mitchell Earle Lowe Eduard Iskandarov Eli Ribble Elizabeth Leddy Emil Melnikov Enrique Saez Eric Sheng Erich Healy Erik Peterson Eugene Chernyshov Eugene Ershov Eugene Naydenov Eugene Nikolaiev Eugene Tolmachev Evan Kepner Evert Lammerts Felix Yan Fernanda Guimarães FichteFoll Florian Scheffler Franek Magiera Frederik Gladhorn Frederik Peter Aalund Gabriel Tremblay Gang Ji Gary Leung Gary Wilson Jr. Gene Hoffman Gennady Andreyev Georges Dubus Greg Holt Gregory Haynes Grigoriy Soldatov Guillaume Leurquin Gus Goulart Gustavo Carneiro Günther Jena Hans Adema Harmon Y. Harry Liu Hiroshi Ogawa Hrishikesh Paranjape Hu Bo Hugh Young Hugo Herter Hugo Hromic Hugo van Kemenade Hynek Schlawack Igor Alexandrov Igor Bolshakov Igor Davydenko Igor Mozharovsky Igor Pavlov Illia Volochii Ilya Chichak Ilya Gruzinov Ingmar Steen Ivan Lakovic Ivan Larin J. Nick Koston Jacob Champion Jacob Henner Jaesung Lee Jake Davis Jakob Ackermann Jakub Wilk James Ward Jan Buchar Jan Gosmann Jarno Elonen Jashandeep Sohi Javier Torres Jean-Baptiste Estival Jens Steinhauser Jeonghun Lee Jeongkyu Shin Jeroen van der Heijden Jesus Cea Jian Zeng Jinkyu Yi Joel Watts John Feusi John Parton Jon Nabozny Jonas Krüger Svensson Jonas Obrist Jonathan Ballet Jonathan Wright Jonny Tan Joongi Kim Jordan Borean Josep Cugat Josh Junon Joshu Coats Julia Tsemusheva Julien Duponchelle Jungkook Park Junjie Tao Junyeong Jeong Justas Trimailovas Justin Foo Justin Turner Arthur Kay Zheng Kevin Samuel Kilian Guillaume Kimmo Parviainen-Jalanko Kirill Klenov Kirill Malovitsa Kirill Potapenko Konstantin Shutkin Konstantin Valetov Krzysztof Blazewicz Kyrylo Perevozchikov Kyungmin Lee Lars P. 
Søndergaard Lee LieWhite Liu Hua Louis-Philippe Huberdeau Loïc Lajeanne Lu Gong Lubomir Gelo Ludovic Gasc Luis Pedrosa Lukasz Marcin Dobrzanski Lénárd Szolnoki Makc Belousow Manuel Miranda Marat Sharafutdinov Marc Mueller Marco Paolini Marcus Stojcevich Mariano Anaya Mariusz Masztalerczuk Mark Larah Marko Kohtala Martijn Pieters Martin Melka Martin Richard Martin Sucha Mathias Fröjdman Mathieu Dugré Matt VanEseltine Matthew Go Matthias Marquardt Matthieu Hauglustaine Matthieu Rigal Matvey Tingaev Meet Mangukiya Meshya Michael Ihnatenko Michał Górny Mikhail Burshteyn Mikhail Kashkin Mikhail Lukyanchenko Mikhail Nacharov Mingjie Zhao Misha Behersky Mitchell Ferree Morgan Delahaye-Prat Moss Collum Mun Gwan-gyeong Navid Sheikhol Nicolas Braem Nikolay Kim Nikolay Novik Nikolay Tiunov Nándor Mátravölgyi Oisin Aylward Olaf Conradi Oleg Höfling Pahaz Blinov Panagiotis Kolokotronis Pankaj Pandey Parag Jain Parman Mohammadalizadeh Patrick Lee Pau Freixes Paul Colomiets Paul J. Dorn Paulius Šileikis Paulus Schoutsen Pavel Kamaev Pavel Polyakov Pavel Sapezhko Pavol Vargovčík Pawel Kowalski Pawel Miech Pepe Osca Phebe Polk Philipp A. 
Pierre-Louis Peeters Pieter van Beek Qiao Han Rafael Viotti Rahul Nahata Raphael Bialon Raúl Cumplido Required Field Robert Lu Robert Nikolich Rodrigo Nogueira Roman Markeloff Roman Podoliaka Roman Postnov Rong Zhang Rouven Bauer Samir Akarioh Samuel Colvin Samuel Gaist Sean Hunt Sebastian Acuna Sebastian Hanula Sebastian Hüther Sebastien Geffroy SeongSoo Cho Sergey Ninua Sergey Skripnick Serhii Charykov Serhii Kostel Serhiy Storchaka Shubh Agarwal Simon Kennedy Sin-Woo Bang Soheil Dolatabadi Stanislas Plum Stanislav Prokop Stefan Tjarks Stepan Pletnev Stephan Jaensch Stephen Cirelli Stephen Granade Steve Repsher Steven Seguin Sunghyun Hwang Sunit Deshpande Sviatoslav Bulbakha Sviatoslav Sydorenko Taha Jahangir Taras Voinarovskyi Terence Honles Thanos Lefteris Thijs Vermeir Thomas Forbes Thomas Grainger Tim Menninger Tolga Tezel Tom Whittock Tomasz Trebski Toshiaki Tanaka Trevor Gamblin Trinh Hoang Nhu Tymofii Tsiapa Vadim Suharnikov Vaibhav Sagar Vamsi Krishna Avula Vasiliy Faronov Vasyl Baran Viacheslav Greshilov Victor Collod Victor Kovtun Victor Makarov Vikas Kawadia Viktor Danyliuk Ville Skyttä Vincent Maillol Vitalik Verhovodov Vitaly Haritonsky Vitaly Magerya Vizonex Vladimir Kamarzin Vladimir Kozlovski Vladimir Rutsky Vladimir Shulyak Vladimir Vinogradenko Vladimir Zakharov Vladyslav Bohaichuk Vladyslav Bondar Vojtěch Boček W. Trevor King Wei Lin Weiwei Wang Will Fatherley Will McGugan Willem de Groot William Grzybowski William S. Wilson Ong wouter bolsterlee Xavier Halloran Xi Rui Xiang Li Yang Zhou Yannick Koechlin Yannick Péroux Ye Cao Yegor Roganov Yifei Kong Young-Ho Cha Yuriy Shatrov Yury Pliner Yury Selivanov Yusuke Tsutsumi Yuval Ofir Yuvi Panda Zainab Lawal Zeal Wierslee Zlatan Sičanica Łukasz Setla Марк Коренберг Семён Марьясин ================================================ FILE: LICENSE.txt ================================================ Copyright aio-libs contributors. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: MANIFEST.in ================================================ include LICENSE.txt include CHANGES.rst include README.rst include CONTRIBUTORS.txt include Makefile graft aiohttp graft docs graft examples graft tests graft tools graft requirements graft vendor global-exclude *.pyc global-exclude *.pyd global-exclude *.so global-exclude *.lib global-exclude *.dll global-exclude *.a global-exclude *.obj exclude aiohttp/*.html prune docs/_build ================================================ FILE: Makefile ================================================ # Some simple testing tasks (sorry, UNIX only). 
to-hash-one = $(dir $1).hash/$(addsuffix .hash,$(notdir $1)) to-hash = $(foreach fname,$1,$(call to-hash-one,$(fname))) CYS := $(wildcard aiohttp/*.pyx) $(wildcard aiohttp/*.pyi) $(wildcard aiohttp/*.pxd) $(wildcard aiohttp/_websocket/*.pyx) $(wildcard aiohttp/_websocket/*.pyi) $(wildcard aiohttp/_websocket/*.pxd) PYXS := $(wildcard aiohttp/*.pyx) $(wildcard aiohttp/_websocket/*.pyx) CS := $(wildcard aiohttp/*.c) $(wildcard aiohttp/_websocket/*.c) PYS := $(wildcard aiohttp/*.py) $(wildcard aiohttp/_websocket/*.py) IN := doc-spelling lint cython dev ALLS := $(sort $(CYS) $(CS) $(PYS) $(REQS)) .PHONY: all all: test tst: @echo $(call to-hash,requirements/cython.txt) @echo $(call to-hash,aiohttp/%.pyx) # Recipe from https://www.cmcrossroads.com/article/rebuilding-when-files-checksum-changes FORCE: # check_sum.py works perfectly fine but slow when called for every file from $(ALLS) # (perhaps even several times for each file). # That is why much less readable but faster solution exists ifneq (, $(shell command -v sha256sum)) %.hash: FORCE $(eval $@_ABS := $(abspath $@)) $(eval $@_NAME := $($@_ABS)) $(eval $@_HASHDIR := $(dir $($@_ABS))) $(eval $@_TMP := $($@_HASHDIR)../$(notdir $($@_ABS))) $(eval $@_ORIG := $(subst /.hash/../,/,$(basename $($@_TMP)))) @#echo ==== $($@_ABS) $($@_HASHDIR) $($@_NAME) $($@_TMP) $($@_ORIG) @if ! (sha256sum --check $($@_ABS) 1>/dev/null 2>/dev/null); then \ mkdir -p $($@_HASHDIR); \ echo re-hash $($@_ORIG); \ sha256sum $($@_ORIG) > $($@_ABS); \ fi else %.hash: FORCE @./tools/check_sum.py $@ # --debug endif # Enumerate intermediate files to don't remove them automatically. 
.SECONDARY: $(call to-hash,$(ALLS)) .update-pip: @python -m pip install --upgrade pip .install-cython: .update-pip $(call to-hash,requirements/cython.txt) @python -m pip install -r requirements/cython.in -c requirements/cython.txt @touch .install-cython aiohttp/_find_header.c: $(call to-hash,aiohttp/hdrs.py ./tools/gen.py) ./tools/gen.py # Special case for reader since we want to be able to disable # the extension with AIOHTTP_NO_EXTENSIONS aiohttp/_websocket/reader_c.c: aiohttp/_websocket/reader_c.py cython -3 -X freethreading_compatible=True -o $@ $< -I aiohttp -Werror # _find_headers generator creates _headers.pyi as well aiohttp/%.c: aiohttp/%.pyx $(call to-hash,$(CYS)) aiohttp/_find_header.c cython -3 -X freethreading_compatible=True -o $@ $< -I aiohttp -Werror aiohttp/_websocket/%.c: aiohttp/_websocket/%.pyx $(call to-hash,$(CYS)) cython -3 -X freethreading_compatible=True -o $@ $< -I aiohttp -Werror vendor/llhttp/node_modules: vendor/llhttp/package.json cd vendor/llhttp; npm ci .llhttp-gen: vendor/llhttp/node_modules $(MAKE) -C vendor/llhttp generate @touch .llhttp-gen .PHONY: generate-llhttp generate-llhttp: .llhttp-gen .PHONY: cythonize cythonize: .install-cython $(PYXS:.pyx=.c) aiohttp/_websocket/reader_c.c .PHONY: cythonize-nodeps cythonize-nodeps: $(PYXS:.pyx=.c) aiohttp/_websocket/reader_c.c .install-deps: .install-cython $(PYXS:.pyx=.c) aiohttp/_websocket/reader_c.c $(call to-hash,$(CYS) $(REQS)) @python -m pip install -r requirements/dev.in -c requirements/dev.txt @touch .install-deps .PHONY: lint lint: fmt mypy .PHONY: fmt format fmt format: python -m pre_commit run --all-files --show-diff-on-failure .PHONY: mypy mypy: mypy .develop: .install-deps generate-llhttp $(call to-hash,$(PYS) $(CYS) $(CS)) python -m pip install -e . 
-c requirements/runtime-deps.txt @touch .develop .PHONY: test test: .develop @pytest -q .PHONY: vtest vtest: .develop @pytest -s -v @python -X dev -m pytest --cov-append -s -v -m dev_mode .PHONY: vvtest vvtest: .develop @pytest -vv @python -X dev -m pytest --cov-append -s -vv -m dev_mode .PHONY: cov-dev cov-dev: .develop @pytest --cov-report=html @echo "xdg-open file://`pwd`/htmlcov/index.html" define run_tests_in_docker DOCKER_BUILDKIT=1 docker build --build-arg PYTHON_VERSION=$(1) --build-arg AIOHTTP_NO_EXTENSIONS=$(2) -t "aiohttp-test-$(1)-$(2)" -f tools/testing/Dockerfile . docker run --rm -ti -v `pwd`:/src -w /src "aiohttp-test-$(1)-$(2)" $(TEST_SPEC) endef .PHONY: clean clean: @rm -rf `find . -name __pycache__` @rm -rf `find . -name .hash` @rm -rf `find . -name .md5` # old styling @rm -f `find . -type f -name '*.py[co]' ` @rm -f `find . -type f -name '*~' ` @rm -f `find . -type f -name '.*~' ` @rm -f `find . -type f -name '@*' ` @rm -f `find . -type f -name '#*#' ` @rm -f `find . -type f -name '*.orig' ` @rm -f `find . -type f -name '*.rej' ` @rm -f `find . 
-type f -name '*.md5' ` # old styling @rm -f .coverage @rm -rf htmlcov @rm -rf build @rm -rf cover @make -C docs clean @python setup.py clean @rm -f aiohttp/*.so @rm -f aiohttp/*.pyd @rm -f aiohttp/*.html @rm -f aiohttp/_frozenlist.c @rm -f aiohttp/_find_header.c @rm -f aiohttp/_http_parser.c @rm -f aiohttp/_http_writer.c @rm -f aiohttp/_websocket.c @rm -f aiohttp/_websocket/reader_c.c @rm -rf .tox @rm -f .develop @rm -f .flake @rm -rf aiohttp.egg-info @rm -f .install-deps @rm -f .install-cython @rm -rf vendor/llhttp/node_modules @rm -f .llhttp-gen @$(MAKE) -C vendor/llhttp clean .PHONY: doc doc: @make -C docs html SPHINXOPTS="-W --keep-going -n -E" @echo "open file://`pwd`/docs/_build/html/index.html" .PHONY: doc-spelling doc-spelling: @make -C docs spelling SPHINXOPTS="-W --keep-going -n -E" .PHONY: install install: .update-pip @python -m pip install -r requirements/dev.in -c requirements/dev.txt .PHONY: install-dev install-dev: .develop .PHONY: sync-direct-runtime-deps sync-direct-runtime-deps: @echo Updating 'requirements/runtime-deps.in' from 'pyproject.toml'... >&2 @python requirements/sync-direct-runtime-deps.py ================================================ FILE: README.rst ================================================ ================================== Async http client/server framework ================================== .. image:: https://raw.githubusercontent.com/aio-libs/aiohttp/master/docs/aiohttp-plain.svg :height: 64px :width: 64px :alt: aiohttp logo | .. image:: https://github.com/aio-libs/aiohttp/workflows/CI/badge.svg :target: https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI :alt: GitHub Actions status for master branch .. image:: https://codecov.io/gh/aio-libs/aiohttp/branch/master/graph/badge.svg :target: https://codecov.io/gh/aio-libs/aiohttp :alt: codecov.io status for master branch .. image:: https://badge.fury.io/py/aiohttp.svg :target: https://pypi.org/project/aiohttp :alt: Latest PyPI package version .. 
image:: https://img.shields.io/pypi/dm/aiohttp :target: https://pypistats.org/packages/aiohttp :alt: Downloads count .. image:: https://readthedocs.org/projects/aiohttp/badge/?version=latest :target: https://docs.aiohttp.org/ :alt: Latest Read The Docs .. image:: https://img.shields.io/endpoint?url=https://codspeed.io/badge.json :target: https://codspeed.io/aio-libs/aiohttp :alt: Codspeed.io status for aiohttp Key Features ============ - Supports both client and server side of HTTP protocol. - Supports both client and server Web-Sockets out-of-the-box and avoids Callback Hell. - Provides Web-server with middleware and pluggable routing. Getting started =============== Client ------ To get something from the web: .. code-block:: python import aiohttp import asyncio async def main(): async with aiohttp.ClientSession() as session: async with session.get('http://python.org') as response: print("Status:", response.status) print("Content-type:", response.headers['content-type']) html = await response.text() print("Body:", html[:15], "...") asyncio.run(main()) This prints: .. code-block:: Status: 200 Content-type: text/html; charset=utf-8 Body: ... Coming from `requests `_ ? Read `why we need so many lines `_. Server ------ An example using a simple server: .. 
code-block:: python # examples/server_simple.py from aiohttp import web async def handle(request): name = request.match_info.get('name', "Anonymous") text = "Hello, " + name return web.Response(text=text) async def wshandle(request): ws = web.WebSocketResponse() await ws.prepare(request) async for msg in ws: if msg.type == web.WSMsgType.text: await ws.send_str("Hello, {}".format(msg.data)) elif msg.type == web.WSMsgType.binary: await ws.send_bytes(msg.data) elif msg.type == web.WSMsgType.close: break return ws app = web.Application() app.add_routes([web.get('/', handle), web.get('/echo', wshandle), web.get('/{name}', handle)]) if __name__ == '__main__': web.run_app(app) Documentation ============= https://aiohttp.readthedocs.io/ Demos ===== https://github.com/aio-libs/aiohttp-demos External links ============== * `Third party libraries `_ * `Built with aiohttp `_ * `Powered by aiohttp `_ Feel free to make a Pull Request for adding your link to these pages! Communication channels ====================== *aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions *Matrix*: `#aio-libs:matrix.org `_ We support `Stack Overflow `_. Please add *aiohttp* tag to your question there. Requirements ============ - multidict_ - yarl_ - frozenlist_ Optionally you may install the aiodns_ library (highly recommended for sake of speed). .. _aiodns: https://pypi.python.org/pypi/aiodns .. _multidict: https://pypi.python.org/pypi/multidict .. _frozenlist: https://pypi.org/project/frozenlist/ .. _yarl: https://pypi.python.org/pypi/yarl License ======= ``aiohttp`` is offered under the Apache 2 license. Keepsafe ======== The aiohttp community would like to thank Keepsafe (https://www.getkeepsafe.com) for its support in the early days of the project. 
Source code =========== The latest developer version is available in a GitHub repository: https://github.com/aio-libs/aiohttp Benchmarks ========== If you are interested in efficiency, the AsyncIO community maintains a list of benchmarks on the official wiki: https://github.com/python/asyncio/wiki/Benchmarks -------- .. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat :target: https://matrix.to/#/%23aio-libs:matrix.org :alt: Matrix Room — #aio-libs:matrix.org .. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat :target: https://matrix.to/#/%23aio-libs-space:matrix.org :alt: Matrix Space — #aio-libs-space:matrix.org .. image:: https://insights.linuxfoundation.org/api/badge/health-score?project=aiohttp :target: https://insights.linuxfoundation.org/project/aiohttp :alt: LFX Health Score ================================================ FILE: aiohttp/__init__.py ================================================ __version__ = "4.0.0a2.dev0" from typing import TYPE_CHECKING from . 
import hdrs from .client import ( BaseConnector, ClientConnectionError, ClientConnectionResetError, ClientConnectorCertificateError, ClientConnectorDNSError, ClientConnectorError, ClientConnectorSSLError, ClientError, ClientHttpProxyError, ClientOSError, ClientPayloadError, ClientProxyConnectionError, ClientRequest, ClientResponse, ClientResponseError, ClientSession, ClientSSLError, ClientTimeout, ClientWebSocketResponse, ClientWSTimeout, ConnectionTimeoutError, ContentTypeError, Fingerprint, InvalidURL, InvalidUrlClientError, InvalidUrlRedirectClientError, NamedPipeConnector, NonHttpUrlClientError, NonHttpUrlRedirectClientError, RedirectClientError, RequestInfo, ServerConnectionError, ServerDisconnectedError, ServerFingerprintMismatch, ServerTimeoutError, SocketTimeoutError, TCPConnector, TooManyRedirects, UnixConnector, WSMessageTypeError, WSServerHandshakeError, request, ) from .client_middleware_digest_auth import DigestAuthMiddleware from .client_middlewares import ClientHandlerType, ClientMiddlewareType from .compression_utils import set_zlib_backend from .connector import AddrInfoType, SocketFactoryType from .cookiejar import CookieJar, DummyCookieJar from .formdata import FormData from .helpers import BasicAuth, ChainMapProxy, ETag from .http import ( HttpVersion, HttpVersion10, HttpVersion11, WebSocketError, WSCloseCode, WSMessage, WSMsgType, ) from .multipart import ( BadContentDispositionHeader, BadContentDispositionParam, BodyPartReader, MultipartReader, MultipartWriter, content_disposition_filename, parse_content_disposition, ) from .payload import ( PAYLOAD_REGISTRY, AsyncIterablePayload, BufferedReaderPayload, BytesIOPayload, BytesPayload, IOBasePayload, JsonPayload, Payload, StringIOPayload, StringPayload, TextIOPayload, get_payload, payload_type, ) from .resolver import AsyncResolver, DefaultResolver, ThreadedResolver from .streams import EMPTY_PAYLOAD, DataQueue, EofStream, StreamReader from .tracing import ( TraceConfig, 
TraceConnectionCreateEndParams, TraceConnectionCreateStartParams, TraceConnectionQueuedEndParams, TraceConnectionQueuedStartParams, TraceConnectionReuseconnParams, TraceDnsCacheHitParams, TraceDnsCacheMissParams, TraceDnsResolveHostEndParams, TraceDnsResolveHostStartParams, TraceRequestChunkSentParams, TraceRequestEndParams, TraceRequestExceptionParams, TraceRequestHeadersSentParams, TraceRequestRedirectParams, TraceRequestStartParams, TraceResponseChunkReceivedParams, ) if TYPE_CHECKING: # At runtime these are lazy-loaded at the bottom of the file. from .worker import GunicornUVLoopWebWorker, GunicornWebWorker __all__: tuple[str, ...] = ( "hdrs", # client "AddrInfoType", "BaseConnector", "ClientConnectionError", "ClientConnectionResetError", "ClientConnectorCertificateError", "ClientConnectorDNSError", "ClientConnectorError", "ClientConnectorSSLError", "ClientError", "ClientHttpProxyError", "ClientOSError", "ClientPayloadError", "ClientProxyConnectionError", "ClientResponse", "ClientRequest", "ClientResponseError", "ClientSSLError", "ClientSession", "ClientTimeout", "ClientWebSocketResponse", "ClientWSTimeout", "ConnectionTimeoutError", "ContentTypeError", "Fingerprint", "InvalidURL", "InvalidUrlClientError", "InvalidUrlRedirectClientError", "NonHttpUrlClientError", "NonHttpUrlRedirectClientError", "RedirectClientError", "RequestInfo", "ServerConnectionError", "ServerDisconnectedError", "ServerFingerprintMismatch", "ServerTimeoutError", "SocketFactoryType", "SocketTimeoutError", "TCPConnector", "TooManyRedirects", "UnixConnector", "NamedPipeConnector", "WSServerHandshakeError", "request", # client_middleware "ClientMiddlewareType", "ClientHandlerType", # cookiejar "CookieJar", "DummyCookieJar", # formdata "FormData", # helpers "BasicAuth", "ChainMapProxy", "DigestAuthMiddleware", "ETag", "set_zlib_backend", # http "HttpVersion", "HttpVersion10", "HttpVersion11", "WSMsgType", "WSCloseCode", "WSMessage", "WebSocketError", # multipart "BadContentDispositionHeader", 
def __getattr__(name: str) -> object:
    """Resolve the gunicorn worker classes lazily on first attribute access.

    Importing gunicorn takes a long time (>100ms), so the worker module is
    only imported when one of the two worker names is actually requested.

    Args:
        name: The module attribute being looked up.

    Returns:
        The requested worker class, or ``None`` when the worker module
        cannot be imported.

    Raises:
        AttributeError: If ``name`` is not one of the lazily loaded workers.
    """
    global GunicornUVLoopWebWorker, GunicornWebWorker

    if name not in ("GunicornUVLoopWebWorker", "GunicornWebWorker"):
        raise AttributeError(f"module {__name__} has no attribute {name}")

    try:
        from .worker import (
            GunicornUVLoopWebWorker as uvloop_worker,
            GunicornWebWorker as web_worker,
        )
    except ImportError:
        return None

    # Cache both classes as real module globals so later lookups bypass
    # this hook entirely.
    GunicornUVLoopWebWorker = uvloop_worker  # type: ignore[misc]
    GunicornWebWorker = web_worker  # type: ignore[misc]

    if name == "GunicornUVLoopWebWorker":
        return uvloop_worker
    return web_worker
# SimpleCookie's pattern for parsing cookies with relaxed validation
# Based on http.cookies pattern but extended to allow more characters in cookie names
# to handle real-world cookies (fixes #2683)
#
# The pattern exposes two named groups that the parsers in this module read
# through ``match.group("key")`` and ``match.group("val")``:
#   key - the cookie or attribute name (always present on a match)
#   val - the raw value, quotes included; ``None`` when there is no "=value"
_COOKIE_PATTERN = re.compile(
    r"""
    \s*                            # Optional whitespace at start of cookie
    (?P<key>                       # Start of group 'key'
    # aiohttp has extended to include [] for compatibility with real-world cookies
    [\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\[\]]+   # Any word of at least one letter
    )                              # End of group 'key'
    (                              # Optional group: there may not be a value.
    \s*=\s*                          # Equal Sign
    (?P<val>                         # Start of group 'val'
    "(?:[^\\"]|\\.)*"                  # Any double-quoted string (properly closed)
    |                                  # or
    "[^";]*                            # Unmatched opening quote (differs from SimpleCookie - issue #7993)
    |                                  # or
    # Special case for "expires" attr - RFC 822, RFC 850, RFC 1036, RFC 1123
    (\w{3,6}day|\w{3}),\s              # Day of the week or abbreviated day (with comma)
    [\w\d\s-]{9,11}\s[\d:]{8}\s        # Date and time in specific format
    (GMT|[+-]\d{4})                    # Timezone: GMT or RFC 2822 offset like -0000, +0100
                                       # NOTE: RFC 2822 timezone support is an aiohttp extension
                                       # for issue #4493 - SimpleCookie does NOT support this
    |                                  # or
    # ANSI C asctime() format: "Wed Jun  9 10:18:14 2021"
    # NOTE: This is an aiohttp extension for issue #4327 - SimpleCookie does NOT support this format
    \w{3}\s+\w{3}\s+[\s\d]\d\s+\d{2}:\d{2}:\d{2}\s+\d{4}
    |                                  # or
    [\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=\[\]]*      # Any word or empty string
    )                                # End of group 'val'
    )?                             # End of optional value group
    \s*                            # Any number of spaces.
    (\s+|;|$)                      # Ending either at space, semicolon, or EOS.
    """,
    re.VERBOSE | re.ASCII,
)
# Bound ``sub`` of the escape-sequence matcher: group 1 captures a three-digit
# octal escape, group 2 captures any other single escaped character.
_unquote_sub = re.compile(r"\\(?:([0-3][0-7][0-7])|(.))").sub


def _unquote_replace(m: re.Match[str]) -> str:
    """
    Return the replacement text for one escape sequence found by _unquote_sub.

    An octal escape becomes the character with that code point; any other
    escaped character is returned without its leading backslash.
    """
    octal_digits = m[1]
    return chr(int(octal_digits, 8)) if octal_digits else m[2]


def _unquote(value: str) -> str:
    """
    Strip surrounding double quotes from a cookie value and decode escapes.

    Vendored from http.cookies._unquote to ensure compatibility.

    Note: The original implementation checked for None; callers of this
    version are expected to pass a string.

    Values that are not wrapped in a pair of double quotes (including
    strings shorter than two characters) are returned unchanged, since
    without quotes there can't be any special characters. See RFC 2109.
    """
    is_quoted = len(value) >= 2 and value[0] == '"' and value[-1] == '"'
    if not is_quoted:
        return value

    # Drop the enclosing quotes, then decode special sequences. Examples:
    #    \012 --> \n
    #    \"   --> "
    return _unquote_sub(_unquote_replace, value[1:-1])
When the regex fails to match a malformed cookie, it falls back to simple parsing to ensure subsequent cookies are not lost https://github.com/aio-libs/aiohttp/issues/11632 Args: header: The Cookie header value to parse Returns: List of (name, Morsel) tuples for compatibility with SimpleCookie.update() """ if not header: return [] morsel: Morsel[str] cookies: list[tuple[str, Morsel[str]]] = [] i = 0 n = len(header) invalid_names = [] while i < n: # Use the same pattern as parse_set_cookie_headers to find cookies match = _COOKIE_PATTERN.match(header, i) if not match: # Fallback for malformed cookies https://github.com/aio-libs/aiohttp/issues/11632 # Find next semicolon to skip or attempt simple key=value parsing next_semi = header.find(";", i) eq_pos = header.find("=", i) # Try to extract key=value if '=' comes before ';' if eq_pos != -1 and (next_semi == -1 or eq_pos < next_semi): end_pos = next_semi if next_semi != -1 else n key = header[i:eq_pos].strip() value = header[eq_pos + 1 : end_pos].strip() # Validate the name (same as regex path) if not _COOKIE_NAME_RE.match(key): invalid_names.append(key) else: morsel = Morsel() morsel.__setstate__( # type: ignore[attr-defined] {"key": key, "value": _unquote(value), "coded_value": value} ) cookies.append((key, morsel)) # Move to next cookie or end i = next_semi + 1 if next_semi != -1 else n continue key = match.group("key") value = match.group("val") or "" i = match.end(0) # Validate the name if not key or not _COOKIE_NAME_RE.match(key): invalid_names.append(key) continue # Create new morsel morsel = Morsel() # Preserve the original value as coded_value (with quotes if present) # We use __setstate__ instead of the public set() API because it allows us to # bypass validation and set already validated state. This is more stable than # setting protected attributes directly and unlikely to change since it would # break pickling. 
def parse_set_cookie_headers(headers: Sequence[str]) -> list[tuple[str, Morsel[str]]]:
    """
    Parse cookie headers using a vendored version of SimpleCookie parsing.

    This implementation is based on SimpleCookie.__parse_string to ensure
    compatibility with how SimpleCookie parses cookies, including handling
    of malformed cookies with missing semicolons.

    This function is used for both Cookie and Set-Cookie headers in order to be
    forgiving. Ideally we would have followed RFC 6265 Section 4.2.1 (for Cookie
    headers) and RFC 6265 Section 5.2 (for Set-Cookie headers), but the
    real world data makes it impossible since we need to be a bit more forgiving.

    NOTE: This implementation differs from SimpleCookie in handling unmatched quotes.
    SimpleCookie will stop parsing when it encounters a cookie value with an unmatched
    quote (e.g., 'cookie="value'), causing subsequent cookies to be silently dropped.
    This implementation handles unmatched quotes more gracefully to prevent cookie loss.
    See https://github.com/aio-libs/aiohttp/issues/7993

    Args:
        headers: Header values to parse; empty entries are skipped.

    Returns:
        List of (name, Morsel) tuples, in the order the cookies were found.
        Cookies with an illegal name are logged and left out of the result.
    """
    parsed_cookies: list[tuple[str, Morsel[str]]] = []

    for header in headers:
        if not header:
            continue

        # Parse cookie string using SimpleCookie's algorithm
        i = 0
        n = len(header)
        # Morsel that attributes are currently attached to; None when the
        # most recent cookie name was rejected.
        current_morsel: Morsel[str] | None = None
        # True once any name=value pair has been seen in this header,
        # whether or not its name was accepted.
        morsel_seen = False

        while 0 <= i < n:
            # Start looking for a cookie
            match = _COOKIE_PATTERN.match(header, i)
            if not match:
                # No more cookies
                break

            key, value = match.group("key"), match.group("val")
            i = match.end(0)
            lower_key = key.lower()

            if key[0] == "$":
                if not morsel_seen:
                    # We ignore attributes which pertain to the cookie
                    # mechanism as a whole, such as "$Version".
                    continue
                # Process as attribute
                if current_morsel is not None:
                    # Strip the leading "$" before checking the attribute name.
                    attr_lower_key = lower_key[1:]
                    if attr_lower_key in _COOKIE_KNOWN_ATTRS:
                        current_morsel[attr_lower_key] = value or ""
            elif lower_key in _COOKIE_KNOWN_ATTRS:
                if not morsel_seen:
                    # Invalid cookie string - attribute before cookie
                    break
                if lower_key in _COOKIE_BOOL_ATTRS:
                    # Boolean attribute with any value should be True
                    if current_morsel is not None and current_morsel.isReservedKey(key):
                        current_morsel[lower_key] = True
                elif value is None:
                    # Invalid cookie string - non-boolean attribute without value
                    break
                elif current_morsel is not None:
                    # Regular attribute with value
                    current_morsel[lower_key] = _unquote(value)
            elif value is not None:
                # This is a cookie name=value pair
                # Validate the name
                if key in _COOKIE_KNOWN_ATTRS or not _COOKIE_NAME_RE.match(key):
                    internal_logger.warning(
                        "Can not load cookies: Illegal cookie name %r", key
                    )
                    current_morsel = None
                else:
                    # Create new morsel
                    current_morsel = Morsel()
                    # Preserve the original value as coded_value (with quotes if present)
                    # We use __setstate__ instead of the public set() API because it allows us to
                    # bypass validation and set already validated state. This is more stable than
                    # setting protected attributes directly and unlikely to change since it would
                    # break pickling.
                    current_morsel.__setstate__(  # type: ignore[attr-defined]
                        {"key": key, "value": _unquote(value), "coded_value": value}
                    )
                    parsed_cookies.append((key, current_morsel))
                # Set even when the name was rejected: attributes that follow
                # are then dropped by the `current_morsel is not None` guards
                # instead of being treated as "attribute before cookie".
                morsel_seen = True
            else:
                # Invalid cookie string - no value for non-attribute
                break

    return parsed_cookies
HPE_CB_CHUNK_COMPLETE, HPE_PAUSED, HPE_PAUSED_UPGRADE, HPE_USER ctypedef llhttp_errno llhttp_errno_t enum llhttp_flags: F_CHUNKED, F_CONTENT_LENGTH enum llhttp_type: HTTP_REQUEST, HTTP_RESPONSE, HTTP_BOTH enum llhttp_method: HTTP_DELETE, HTTP_GET, HTTP_HEAD, HTTP_POST, HTTP_PUT, HTTP_CONNECT, HTTP_OPTIONS, HTTP_TRACE, HTTP_COPY, HTTP_LOCK, HTTP_MKCOL, HTTP_MOVE, HTTP_PROPFIND, HTTP_PROPPATCH, HTTP_SEARCH, HTTP_UNLOCK, HTTP_BIND, HTTP_REBIND, HTTP_UNBIND, HTTP_ACL, HTTP_REPORT, HTTP_MKACTIVITY, HTTP_CHECKOUT, HTTP_MERGE, HTTP_MSEARCH, HTTP_NOTIFY, HTTP_SUBSCRIBE, HTTP_UNSUBSCRIBE, HTTP_PATCH, HTTP_PURGE, HTTP_MKCALENDAR, HTTP_LINK, HTTP_UNLINK, HTTP_SOURCE, HTTP_PRI, HTTP_DESCRIBE, HTTP_ANNOUNCE, HTTP_SETUP, HTTP_PLAY, HTTP_PAUSE, HTTP_TEARDOWN, HTTP_GET_PARAMETER, HTTP_SET_PARAMETER, HTTP_REDIRECT, HTTP_RECORD, HTTP_FLUSH ctypedef llhttp_method llhttp_method_t; void llhttp_settings_init(llhttp_settings_t* settings) void llhttp_init(llhttp_t* parser, llhttp_type type, const llhttp_settings_t* settings) llhttp_errno_t llhttp_execute(llhttp_t* parser, const char* data, size_t len) int llhttp_should_keep_alive(const llhttp_t* parser) void llhttp_resume_after_upgrade(llhttp_t* parser) llhttp_errno_t llhttp_get_errno(const llhttp_t* parser) const char* llhttp_get_error_reason(const llhttp_t* parser) const char* llhttp_get_error_pos(const llhttp_t* parser) const char* llhttp_method_name(llhttp_method_t method) void llhttp_set_lenient_headers(llhttp_t* parser, int enabled) void llhttp_set_lenient_optional_cr_before_lf(llhttp_t* parser, int enabled) void llhttp_set_lenient_spaces_after_chunk_size(llhttp_t* parser, int enabled) ================================================ FILE: aiohttp/_find_header.h ================================================ #ifndef _FIND_HEADERS_H #define _FIND_HEADERS_H #ifdef __cplusplus extern "C" { #endif int find_header(const char *str, int size); #ifdef __cplusplus } #endif #endif ================================================ FILE: 
aiohttp/_find_header.pxd ================================================ cdef extern from "_find_header.h": int find_header(char *, int) ================================================ FILE: aiohttp/_http_parser.pyx ================================================ # Based on https://github.com/MagicStack/httptools # from cpython cimport ( Py_buffer, PyBUF_SIMPLE, PyBuffer_Release, PyBytes_AsString, PyBytes_AsStringAndSize, PyObject_GetBuffer, ) from cpython.mem cimport PyMem_Free, PyMem_Malloc from libc.limits cimport ULLONG_MAX from libc.string cimport memcpy from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiDictProxy from yarl import URL as _URL from aiohttp import hdrs from aiohttp.helpers import DEBUG, set_exception from .http_exceptions import ( BadHttpMessage, BadHttpMethod, BadStatusLine, ContentLengthError, InvalidHeader, InvalidURLError, LineTooLong, PayloadEncodingError, TransferEncodingError, ) from .http_parser import DeflateBuffer as _DeflateBuffer from .http_writer import ( HttpVersion as _HttpVersion, HttpVersion10 as _HttpVersion10, HttpVersion11 as _HttpVersion11, ) from .streams import EMPTY_PAYLOAD as _EMPTY_PAYLOAD, StreamReader as _StreamReader cimport cython from aiohttp cimport _cparser as cparser include "_headers.pxi" from aiohttp cimport _find_header ALLOWED_UPGRADES = frozenset({"websocket"}) DEF DEFAULT_FREELIST_SIZE = 250 cdef extern from "Python.h": int PyByteArray_Resize(object, Py_ssize_t) except -1 Py_ssize_t PyByteArray_Size(object) except -1 char* PyByteArray_AsString(object) __all__ = ('HttpRequestParser', 'HttpResponseParser', 'RawRequestMessage', 'RawResponseMessage') cdef object URL = _URL cdef object URL_build = URL.build cdef object CIMultiDict = _CIMultiDict cdef object CIMultiDictProxy = _CIMultiDictProxy cdef object HttpVersion = _HttpVersion cdef object HttpVersion10 = _HttpVersion10 cdef object HttpVersion11 = _HttpVersion11 cdef object SEC_WEBSOCKET_KEY1 = hdrs.SEC_WEBSOCKET_KEY1 cdef 
@cython.freelist(DEFAULT_FREELIST_SIZE)
cdef class RawRequestMessage:
    # Parsed request line and headers of a single HTTP request.
    # All attributes are read-only from Python code.
    cdef readonly str method
    cdef readonly str path
    cdef readonly object version  # HttpVersion
    cdef readonly object headers  # CIMultiDict
    cdef readonly object raw_headers  # tuple
    cdef readonly object should_close
    cdef readonly object compression
    cdef readonly object upgrade
    cdef readonly object chunked
    cdef readonly object url  # yarl.URL

    def __init__(self, method, path, version, headers, raw_headers,
                 should_close, compression, upgrade, chunked, url):
        self.method = method
        self.path = path
        self.version = version
        self.headers = headers
        self.raw_headers = raw_headers
        self.should_close = should_close
        self.compression = compression
        self.upgrade = upgrade
        self.chunked = chunked
        self.url = url

    def __repr__(self):
        # Render every field as name=repr(value), in declaration order.
        info = []
        info.append(("method", self.method))
        info.append(("path", self.path))
        info.append(("version", self.version))
        info.append(("headers", self.headers))
        info.append(("raw_headers", self.raw_headers))
        info.append(("should_close", self.should_close))
        info.append(("compression", self.compression))
        info.append(("upgrade", self.upgrade))
        info.append(("chunked", self.chunked))
        info.append(("url", self.url))
        sinfo = ', '.join(name + '=' + repr(val) for name, val in info)
        # Wrap the field list in the class name; returning a bare '' here
        # would throw away everything collected above.
        return '<RawRequestMessage(' + sinfo + ')>'

    def _replace(self, **dct):
        # Return a copy of this message with the given fields overridden.
        # Keyword names that are not fields of the message are ignored.
        cdef RawRequestMessage ret
        ret = _new_request_message(self.method,
                                   self.path,
                                   self.version,
                                   self.headers,
                                   self.raw_headers,
                                   self.should_close,
                                   self.compression,
                                   self.upgrade,
                                   self.chunked,
                                   self.url)
        if "method" in dct:
            ret.method = dct["method"]
        if "path" in dct:
            ret.path = dct["path"]
        if "version" in dct:
            ret.version = dct["version"]
        if "headers" in dct:
            ret.headers = dct["headers"]
        if "raw_headers" in dct:
            ret.raw_headers = dct["raw_headers"]
        if "should_close" in dct:
            ret.should_close = dct["should_close"]
        if "compression" in dct:
            ret.compression = dct["compression"]
        if "upgrade" in dct:
            ret.upgrade = dct["upgrade"]
        if "chunked" in dct:
            ret.chunked = dct["chunked"]
        if "url" in dct:
            ret.url = dct["url"]
        return ret
# Build a RawResponseMessage by allocating it with __new__ and assigning each
# field directly, so RawResponseMessage.__init__ is not invoked.
# The arguments map one-to-one onto the attributes of the returned message.
cdef _new_response_message(object version,
                           int code,
                           str reason,
                           object headers,
                           object raw_headers,
                           bint should_close,
                           object compression,
                           bint upgrade,
                           bint chunked):
    cdef RawResponseMessage ret
    ret = RawResponseMessage.__new__(RawResponseMessage)
    ret.version = version
    ret.code = code
    ret.reason = reason
    ret.headers = headers
    ret.raw_headers = raw_headers
    ret.should_close = should_close
    ret.compression = compression
    ret.upgrade = upgrade
    ret.chunked = chunked
    return ret
str _content_encoding Py_buffer py_buf def __cinit__(self): self._cparser = \ PyMem_Malloc(sizeof(cparser.llhttp_t)) if self._cparser is NULL: raise MemoryError() self._csettings = \ PyMem_Malloc(sizeof(cparser.llhttp_settings_t)) if self._csettings is NULL: raise MemoryError() def __dealloc__(self): PyMem_Free(self._cparser) PyMem_Free(self._csettings) cdef _init( self, cparser.llhttp_type mode, object protocol, object loop, int limit, object timer=None, size_t max_line_size=8190, size_t max_headers=128, size_t max_field_size=8190, payload_exception=None, bint response_with_body=True, bint read_until_eof=False, bint auto_decompress=True, ): cparser.llhttp_settings_init(self._csettings) cparser.llhttp_init(self._cparser, mode, self._csettings) self._cparser.data = self self._cparser.content_length = 0 self._protocol = protocol self._loop = loop self._timer = timer self._buf = bytearray() self._payload = None self._payload_error = 0 self._payload_exception = payload_exception self._messages = [] self._raw_name = EMPTY_BYTES self._raw_value = EMPTY_BYTES self._has_value = False self._header_name_size = 0 self._max_line_size = max_line_size self._max_headers = max_headers self._max_field_size = max_field_size self._response_with_body = response_with_body self._read_until_eof = read_until_eof self._upgraded = False self._auto_decompress = auto_decompress self._content_encoding = None self._csettings.on_url = cb_on_url self._csettings.on_status = cb_on_status self._csettings.on_header_field = cb_on_header_field self._csettings.on_header_value = cb_on_header_value self._csettings.on_headers_complete = cb_on_headers_complete self._csettings.on_body = cb_on_body self._csettings.on_message_begin = cb_on_message_begin self._csettings.on_message_complete = cb_on_message_complete self._csettings.on_chunk_header = cb_on_chunk_header self._csettings.on_chunk_complete = cb_on_chunk_complete self._last_error = None self._limit = limit cdef _process_header(self): cdef str value if 
    cdef _on_header_field(self, char* at, size_t length):
        # Record a fragment of a header name; the fragment is the first
        # `length` bytes at `at`.
        if self._has_value:
            # A value has already been recorded, so this fragment starts a
            # new header: flush the pending name/value pair first.
            self._process_header()

        if self._raw_name is EMPTY_BYTES:
            # EMPTY_BYTES is the "no name yet" sentinel, compared by identity.
            self._raw_name = at[:length]
        else:
            # Continuation of a name already in progress: append to it.
            self._raw_name += at[:length]
self._upgraded = True # do not support old websocket spec if SEC_WEBSOCKET_KEY1 in headers: raise InvalidHeader(SEC_WEBSOCKET_KEY1) encoding = None enc = self._content_encoding if enc is not None: self._content_encoding = None if enc.isascii() and enc.lower() in {"gzip", "deflate", "br", "zstd"}: encoding = enc if self._cparser.type == cparser.HTTP_REQUEST: method = http_method_str(self._cparser.method) msg = _new_request_message( method, self._path, self.http_version(), headers, raw_headers, should_close, encoding, upgrade, chunked, self._url) else: msg = _new_response_message( self.http_version(), self._cparser.status_code, self._reason, headers, raw_headers, should_close, encoding, upgrade, chunked) if ( ULLONG_MAX > self._cparser.content_length > 0 or chunked or self._cparser.method == cparser.HTTP_CONNECT or (self._cparser.status_code >= 199 and self._cparser.content_length == 0 and self._read_until_eof) ): payload = StreamReader( self._protocol, timer=self._timer, loop=self._loop, limit=self._limit) else: payload = EMPTY_PAYLOAD self._payload = payload if encoding is not None and self._auto_decompress: self._payload = DeflateBuffer(payload, encoding) if not self._response_with_body: payload = EMPTY_PAYLOAD self._messages.append((msg, payload)) cdef _on_message_complete(self): self._payload.feed_eof() self._payload = None cdef _on_chunk_header(self): self._payload.begin_http_chunk_receiving() cdef _on_chunk_complete(self): self._payload.end_http_chunk_receiving() cdef object _on_status_complete(self): pass cdef inline http_version(self): cdef cparser.llhttp_t* parser = self._cparser if parser.http_major == 1: if parser.http_minor == 0: return HttpVersion10 elif parser.http_minor == 1: return HttpVersion11 return HttpVersion(parser.http_major, parser.http_minor) ### Public API ### def feed_eof(self): cdef bytes desc if self._payload is not None: if self._cparser.flags & cparser.F_CHUNKED: raise TransferEncodingError( "Not enough data to satisfy transfer length 
def feed_data(self, data):
    """Run llhttp over ``data`` and return what the callbacks collected.

    Returns a 3-tuple ``(messages, upgraded, tail)``:

    * ``messages`` -- the ``(msg, payload)`` pairs gathered since the
      previous call (the internal list is replaced by a fresh one), or an
      empty tuple when there are none;
    * ``upgraded`` -- ``True`` when ``self._upgraded`` is set;
    * ``tail`` -- ``data[nb:]`` when upgraded, otherwise ``b""``.

    On an llhttp error other than ``HPE_PAUSED_UPGRADE``, and only when no
    payload error was recorded by ``cb_on_body``, raises either the
    exception stored by a callback in ``self._last_error`` or one built by
    ``parser_error_from_errno``.
    """
    cdef:
        size_t data_len
        size_t nb
        cdef cparser.llhttp_errno_t errno

    PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE)
    # NOTE(review): the buffer fields below are normally used with explicit
    # casts (e.g. to ``char*``); none are present here -- confirm against
    # the checked-in file.
    data_len = self.py_buf.len

    errno = cparser.llhttp_execute(
        self._cparser,
        self.py_buf.buf,
        data_len)

    if errno is cparser.HPE_PAUSED_UPGRADE:
        cparser.llhttp_resume_after_upgrade(self._cparser)

        # Offset of the position llhttp stopped at, relative to the start
        # of the buffer. ``nb`` is assigned only in this branch.
        nb = cparser.llhttp_get_error_pos(self._cparser) - self.py_buf.buf

    PyBuffer_Release(&self.py_buf)

    if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED_UPGRADE):
        if self._payload_error == 0:
            if self._last_error is not None:
                # A callback already captured the real exception.
                ex = self._last_error
                self._last_error = None
            else:
                # Rebuild the offending line: the text before the error
                # position back to the previous CRLF, plus the text after
                # it up to the next CRLF, with a caret under the position.
                after = cparser.llhttp_get_error_pos(self._cparser)
                before = data[:after - self.py_buf.buf]
                after_b = after.split(b"\r\n", 1)[0]
                before = before.rsplit(b"\r\n", 1)[-1]
                data = before + after_b
                pointer = " " * (len(repr(before))-1) + "^"
                ex = parser_error_from_errno(self._cparser, data, pointer)
            self._payload = None
            raise ex

    if self._messages:
        messages = self._messages
        self._messages = []
    else:
        messages = ()

    if self._upgraded:
        return messages, True, data[nb:]
    else:
        return messages, False, b""
cdef object _on_status_complete(self):
    # Turn the request target accumulated in ``self._buf`` into
    # ``self._path`` (decoded text) and ``self._url`` (a URL object).
    #
    # Does nothing when the buffer is empty. The buffer is always emptied
    # afterwards, even if building the URL raises.
    if not self._buf:
        return
    self._path = self._buf.decode('utf-8', 'surrogateescape')
    try:
        if self._cparser.method == cparser.HTTP_CONNECT:
            # authority-form,
            # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3
            self._url = URL.build(authority=self._path, encoded=True)
        elif len(self._path) > 1 and self._path[0] == '/':
            # origin-form,
            # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1
            #
            # Split off the fragment first: everything after the first
            # "#" is fragment, so a "?" inside it must not be treated as
            # the query delimiter, and "?#frag" is an empty query followed
            # by a fragment.
            path, _, fragment = self._path.partition("#")
            path, _, query = path.partition("?")
            self._url = URL.build(
                path=path,
                query_string=query,
                fragment=fragment,
                encoded=True,
            )
        else:
            # absolute-form for proxy maybe,
            # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2
            self._url = URL(self._path, encoded=True)
    finally:
        PyByteArray_Resize(self._buf, 0)
cdef int cb_on_url(cparser.llhttp_t* parser,
                   const char *at, size_t length) except -1:
    # llhttp callback: append one piece of the request target to the
    # parser's scratch buffer.
    #
    # llhttp may deliver the URL in several pieces, so the limit is checked
    # against the accumulated length (buffer so far + this piece), the same
    # way cb_on_header_field checks header names. Checking only ``length``
    # would let a URL split across pieces exceed ``_max_line_size``.
    #
    # Any exception is stored in ``_last_error`` for feed_data() to raise
    # and reported to llhttp as -1; 0 means success.
    cdef HttpParser pyparser = <HttpParser>parser.data
    cdef Py_ssize_t size
    try:
        size = len(pyparser._buf) + length
        if size > pyparser._max_line_size:
            status = pyparser._buf + at[:length]
            raise LineTooLong(status[:100] + b"...", pyparser._max_line_size)
        extend(pyparser._buf, at, length)
    except BaseException as ex:
        pyparser._last_error = ex
        return -1
    else:
        return 0
cdef int cb_on_body(cparser.llhttp_t* parser,
                    const char *at, size_t length) except -1:
    # llhttp callback: feed one piece of body data to the current payload.
    #
    # Unlike the other callbacks, a failure here is not stored in
    # ``_last_error``: the exception is set on the payload object and
    # ``_payload_error`` is set to 1, which feed_data() checks before
    # raising. Returns 0 on success and -1 on failure.
    #
    # NOTE(review): ``parser.data`` is assigned to a typed object variable
    # without an explicit cast -- confirm against the checked-in file.
    cdef HttpParser pyparser = parser.data
    cdef bytes body = at[:length]
    try:
        pyparser._payload.feed_data(body)
    except BaseException as underlying_exc:
        reraised_exc = underlying_exc
        if pyparser._payload_exception is not None:
            # Wrap the error in the configured payload exception type,
            # keeping only its message text.
            reraised_exc = pyparser._payload_exception(str(underlying_exc))

        # The original exception is passed along as the third argument.
        set_exception(pyparser._payload, reraised_exc, underlying_exc)

        pyparser._payload_error = 1
        return -1
    else:
        return 0
cdef parser_error_from_errno(cparser.llhttp_t* parser, data, pointer):
    # Translate the error currently recorded on ``parser`` into an exception
    # instance (returned, not raised).
    #
    # ``data`` is the offending line and ``pointer`` a caret line for it;
    # both are embedded in the message after llhttp's error reason.
    cdef cparser.llhttp_errno_t code = cparser.llhttp_get_errno(parser)
    cdef bytes reason = cparser.llhttp_get_error_reason(parser)

    message = "{}:\n\n {!r}\n {}".format(reason.decode("latin-1"), data, pointer)

    if code == cparser.HPE_INVALID_METHOD:
        return BadHttpMethod(error=message)

    if code == cparser.HPE_INVALID_STATUS or code == cparser.HPE_INVALID_VERSION:
        return BadStatusLine(error=message)

    if code == cparser.HPE_INVALID_URL:
        return InvalidURLError(message)

    # Every other code -- callback failures, invalid constants, header
    # tokens, content length, chunk size, EOF state, transfer encoding and
    # anything unlisted -- maps to the generic error.
    return BadHttpMessage(message)
cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol):
    # Append the UTF-8 encoding of ``symbol`` to ``writer``.
    #
    # Returns -1 as soon as a byte cannot be written, otherwise the result
    # of the last write (0 on success). Surrogate code points and values
    # above 0x10FFFF are silently skipped and count as success.
    cdef uint64_t code = symbol

    # One byte: 0xxxxxxx
    if code < 0x80:
        return _write_byte(writer, <uint8_t>code)

    # Two bytes: 110xxxxx 10xxxxxx
    if code < 0x800:
        if _write_byte(writer, <uint8_t>(0xc0 | (code >> 6))) < 0:
            return -1
        return _write_byte(writer, <uint8_t>(0x80 | (code & 0x3f)))

    # Surrogate code point: nothing is written.
    if 0xD800 <= code <= 0xDFFF:
        return 0

    # Beyond the Unicode range: nothing is written.
    if code > 0x10FFFF:
        return 0

    # Three bytes: 1110xxxx 10xxxxxx 10xxxxxx
    if code < 0x10000:
        if _write_byte(writer, <uint8_t>(0xe0 | (code >> 12))) < 0:
            return -1
        if _write_byte(writer, <uint8_t>(0x80 | ((code >> 6) & 0x3f))) < 0:
            return -1
        return _write_byte(writer, <uint8_t>(0x80 | (code & 0x3f)))

    # Four bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
    if _write_byte(writer, <uint8_t>(0xf0 | (code >> 18))) < 0:
        return -1
    if _write_byte(writer, <uint8_t>(0x80 | ((code >> 12) & 0x3f))) < 0:
        return -1
    if _write_byte(writer, <uint8_t>(0x80 | ((code >> 6) & 0x3f))) < 0:
        return -1
    return _write_byte(writer, <uint8_t>(0x80 | (code & 0x3f)))
) if _write_utf8(writer, ch) < 0: return -1 # --------------- _serialize_headers ---------------------- def _serialize_headers(str status_line, headers): cdef Writer writer cdef object key cdef object val cdef char buf[BUF_SIZE] _init_writer(&writer, buf) try: if _write_str_raise_on_nlcr(&writer, status_line) < 0: raise if _write_byte(&writer, b'\r') < 0: raise if _write_byte(&writer, b'\n') < 0: raise for key, val in headers.items(): if _write_str_raise_on_nlcr(&writer, key) < 0: raise if _write_byte(&writer, b':') < 0: raise if _write_byte(&writer, b' ') < 0: raise if _write_str_raise_on_nlcr(&writer, val) < 0: raise if _write_byte(&writer, b'\r') < 0: raise if _write_byte(&writer, b'\n') < 0: raise if _write_byte(&writer, b'\r') < 0: raise if _write_byte(&writer, b'\n') < 0: raise return PyBytes_FromStringAndSize(writer.buf, writer.pos) finally: _release_writer(&writer) ================================================ FILE: aiohttp/_websocket/__init__.py ================================================ """WebSocket protocol versions 13 and 8.""" ================================================ FILE: aiohttp/_websocket/helpers.py ================================================ """Helpers for WebSocket protocol versions 13 and 8.""" import functools import re from re import Pattern from struct import Struct from typing import TYPE_CHECKING, Final from ..helpers import NO_EXTENSIONS from .models import WSHandshakeError UNPACK_LEN3 = Struct("!Q").unpack_from UNPACK_CLOSE_CODE = Struct("!H").unpack PACK_LEN1 = Struct("!BB").pack PACK_LEN2 = Struct("!BBH").pack PACK_LEN3 = Struct("!BBQ").pack PACK_CLOSE_CODE = Struct("!H").pack PACK_RANDBITS = Struct("!L").pack MSG_SIZE: Final[int] = 2**14 MASK_LEN: Final[int] = 4 WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" # Used by _websocket_mask_python @functools.lru_cache def _xor_table() -> list[bytes]: return [bytes(a ^ b for a in range(256)) for b in range(256)] def _websocket_mask_python(mask: bytes, data: 
bytearray) -> None: """Websocket masking function. `mask` is a `bytes` object of length 4; `data` is a `bytearray` object of any length. The contents of `data` are masked with `mask`, as specified in section 5.3 of RFC 6455. Note that this function mutates the `data` argument. This pure-python implementation may be replaced by an optimized version when available. """ assert isinstance(data, bytearray), data assert len(mask) == 4, mask if data: _XOR_TABLE = _xor_table() a, b, c, d = (_XOR_TABLE[n] for n in mask) data[::4] = data[::4].translate(a) data[1::4] = data[1::4].translate(b) data[2::4] = data[2::4].translate(c) data[3::4] = data[3::4].translate(d) if TYPE_CHECKING or NO_EXTENSIONS: websocket_mask = _websocket_mask_python else: try: from .mask import _websocket_mask_cython # type: ignore[import-not-found] websocket_mask = _websocket_mask_cython except ImportError: # pragma: no cover websocket_mask = _websocket_mask_python _WS_EXT_RE: Final[Pattern[str]] = re.compile( r"^(?:;\s*(?:" r"(server_no_context_takeover)|" r"(client_no_context_takeover)|" r"(server_max_window_bits(?:=(\d+))?)|" r"(client_max_window_bits(?:=(\d+))?)))*$" ) _WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?") def ws_ext_parse(extstr: str | None, isserver: bool = False) -> tuple[int, bool]: if not extstr: return 0, False compress = 0 notakeover = False for ext in _WS_EXT_RE_SPLIT.finditer(extstr): defext = ext.group(1) # Return compress = 15 when get `permessage-deflate` if not defext: compress = 15 break match = _WS_EXT_RE.match(defext) if match: compress = 15 if isserver: # Server never fail to detect compress handshake. 
def ws_ext_gen(
    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    """Build the ``permessage-deflate`` extension header value.

    ``compress`` is the window size in bits and must be between 9 and 15
    inclusive, otherwise ``ValueError`` is raised. The returned string always
    starts with ``permessage-deflate``; the other parameters are joined
    with ``"; "`` in a fixed order:

    * ``client_max_window_bits`` unless ``isserver`` is true;
    * ``server_max_window_bits=<compress>`` when ``compress`` is below 15;
    * ``server_no_context_takeover`` when ``server_notakeover`` is true.
    """
    # client_notakeover=False not used for server
    # compress wbit 8 does not support in zlib
    if compress > 15 or compress < 9:
        raise ValueError(
            "Compress wbits must between 9 and 15, zlib does not support wbits=8"
        )

    # (include?, parameter text) in the order they must appear.
    optional = (
        (not isserver, "client_max_window_bits"),
        (compress < 15, "server_max_window_bits=" + str(compress)),
        (server_notakeover, "server_no_context_takeover"),
    )
    chosen = [text for wanted, text in optional if wanted]
    return "; ".join(["permessage-deflate", *chosen])
cpdef void _websocket_mask_cython(bytes mask, bytearray data):
    """XOR ``data`` in place with the 4-byte ``mask``.

    Note, this function mutates its `data` argument.

    The mask repeats every four bytes, so the buffer is processed in
    8-byte words (on platforms where ``size_t`` is at least 8 bytes),
    then 4-byte words, then byte by byte for the remaining tail.
    """
    cdef:
        Py_ssize_t data_len, i
        # bit operations on signed integers are implementation-specific
        unsigned char * in_buf
        const unsigned char * mask_buf
        uint32_t uint32_msk
        uint64_t uint64_msk

    assert len(mask) == 4

    data_len = len(data)
    in_buf = <unsigned char*>PyByteArray_AsString(data)
    mask_buf = <const unsigned char*>PyBytes_AsString(mask)
    # Load all four mask bytes as one word; reading through a uint32_t
    # pointer (rather than indexing the byte pointer) is what makes the
    # word-wide XORs below apply the whole mask.
    uint32_msk = (<uint32_t*>mask_buf)[0]

    # TODO: align in_data ptr to achieve even faster speeds
    # does it need in python ?! malloc() always aligns to sizeof(long) bytes

    if sizeof(size_t) >= 8:
        # Two copies of the mask side by side cover eight payload bytes.
        uint64_msk = uint32_msk
        uint64_msk = (uint64_msk << 32) | uint32_msk

        while data_len >= 8:
            (<uint64_t*>in_buf)[0] ^= uint64_msk
            in_buf += 8
            data_len -= 8

    while data_len >= 4:
        (<uint32_t*>in_buf)[0] ^= uint32_msk
        in_buf += 4
        data_len -= 4

    # Fewer than four bytes remain; the bytes consumed so far are a multiple
    # of four, so tail byte ``i`` lines up with mask byte ``i``.
    for i in range(0, data_len):
        in_buf[i] ^= mask_buf[i]
class WSMessageText(NamedTuple):
    """WebSocket TEXT message whose payload is a ``str``."""

    # Message payload as text.
    data: str
    # Size recorded for the message; WebSocketDataQueue adds and subtracts
    # this value when accounting for its buffered size.
    size: int
    # Optional extra string; presumably unused for TEXT messages -- TODO
    # confirm against the reader that builds these tuples.
    extra: str | None = None
    # Constant discriminator for this message kind.
    type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT

    def json(
        self, *, loads: Callable[[str | bytes | bytearray], Any] = json.loads
    ) -> Any:
        """Return parsed JSON data.

        ``loads`` is called with ``self.data`` and its result is returned
        unchanged; it defaults to :func:`json.loads`.
        """
        return loads(self.data)
class WebSocketError(Exception):
    """WebSocket protocol parser error.

    Carries a numeric close ``code`` alongside a human-readable message.
    ``args`` is ``(code, message)`` and ``str()`` yields only the message.
    """

    def __init__(self, code: int, message: str) -> None:
        # Both values go into ``args`` so ``__str__`` can read the message.
        super().__init__(code, message)
        self.code = code

    def __str__(self) -> str:
        message = self.args[1]
        return cast(str, message)
# C-level layout and method signatures for the WebSocketReader class.
# The attribute names mirror the ones assigned in the Python
# implementation's __init__.
cdef class WebSocketReader:

    cdef WebSocketDataQueue queue
    cdef unsigned int _max_msg_size
    cdef bint _decode_text

    cdef Exception _exc
    cdef bytearray _partial
    cdef unsigned int _state

    cdef int _opcode
    cdef bint _frame_fin
    cdef int _frame_opcode
    cdef list _payload_fragments
    cdef Py_ssize_t _frame_payload_len

    cdef bytes _tail
    cdef bint _has_mask
    cdef bytes _frame_mask
    cdef Py_ssize_t _payload_bytes_to_read
    cdef unsigned int _payload_len_flag
    cdef int _compressed
    cdef object _decompressobj
    cdef bint _compress

    cpdef tuple feed_data(self, object data)

    @cython.locals(
        is_continuation=bint,
        fin=bint,
        has_partial=bint,
        payload_merged=bytes,
    )
    cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except *

    # Each local is declared exactly once; a repeated keyword in the
    # decorator call adds nothing and only invites drift.
    @cython.locals(
        start_pos=Py_ssize_t,
        data_len=Py_ssize_t,
        length=Py_ssize_t,
        chunk_size=Py_ssize_t,
        chunk_len=Py_ssize_t,
        data_cstr="const unsigned char *",
        first_byte="unsigned char",
        second_byte="unsigned char",
        f_start_pos=Py_ssize_t,
        f_end_pos=Py_ssize_t,
        has_mask=bint,
        fin=bint,
        had_fragments=Py_ssize_t,
        payload_bytearray=bytearray,
    )
    cpdef void _feed_data(self, bytes data) except *
class WebSocketDataQueue:
    """WebSocketDataQueue resumes and pauses an underlying stream.

    It is a destination for WebSocket data: parsed messages are appended with
    :meth:`feed_data` and consumed with :meth:`read`.  The queue tracks the
    summed ``size`` of the buffered messages and asks the protocol to pause
    reading once that total exceeds the threshold, and to resume once the
    total drops back below it.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        """Create an empty queue bound to *protocol* and *loop*.

        :param protocol: object whose ``pause_reading()`` / ``resume_reading()``
            are called for flow control and whose ``_reading_paused`` flag is
            consulted before doing so.
        :param limit: base buffer limit; the pause/resume threshold used by
            this class is ``limit * 2``.
        :param loop: event loop used to create the future that ``read()``
            waits on.
        """
        # Running total of ``data.size`` for every message currently buffered.
        self._size = 0
        self._protocol = protocol
        # Threshold compared against ``_size`` in feed_data/_read_from_buffer.
        self._limit = limit * 2
        self._loop = loop
        self._eof = False
        # Future a pending ``read()`` is blocked on, if any.
        self._waiter: asyncio.Future[None] | None = None
        self._exception: type[BaseException] | BaseException | None = None
        self._buffer: deque[WSMessage] = deque()
        # Bound methods of the deque stored once on the instance
        # (presumably to skip the attribute lookup on the hot path).
        self._get_buffer = self._buffer.popleft
        self._put_buffer = self._buffer.append

    def is_eof(self) -> bool:
        """Return True once ``feed_eof()`` or ``set_exception()`` was called."""
        return self._eof

    def exception(self) -> type[BaseException] | BaseException | None:
        """Return the stored exception, or None if none is set."""
        return self._exception

    def set_exception(
        self,
        exc: type[BaseException] | BaseException,
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Mark the queue as finished with *exc*.

        The exception is stored so that ``read()`` raises it once the buffer
        is drained.  If a reader is currently waiting, its future is detached
        and handed to the ``set_exception`` helper together with *exc* and
        *exc_cause*.
        """
        self._eof = True
        self._exception = exc
        if (waiter := self._waiter) is not None:
            self._waiter = None
            set_exception(waiter, exc, exc_cause)

    def _release_waiter(self) -> None:
        """Wake the pending ``read()`` call, if there is one."""
        if (waiter := self._waiter) is None:
            return
        self._waiter = None
        # The future may already be done (e.g. cancelled); only resolve it
        # when it is still pending.
        if not waiter.done():
            waiter.set_result(None)

    def feed_eof(self) -> None:
        """Signal that no more data will arrive and wake any waiting reader."""
        self._eof = True
        self._release_waiter()
        self._exception = None  # Break cyclic references

    def feed_data(self, data: "WSMessage") -> None:
        """Append *data* to the queue and apply back-pressure if needed.

        ``data.size`` is added to the running total; reading is paused on the
        protocol when the total exceeds the threshold and the protocol is not
        already paused.
        """
        size = data.size
        self._size += size
        self._put_buffer(data)
        self._release_waiter()
        if self._size > self._limit and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next buffered message, waiting for one if necessary.

        Waiting only happens when the buffer is empty and EOF has not been
        reached.  Raises the stored exception, or ``EofStream``, when the
        buffer is empty after waking up (see ``_read_from_buffer``).
        """
        if not self._buffer and not self._eof:
            # Only a single reader may wait at a time.
            assert not self._waiter
            self._waiter = self._loop.create_future()
            try:
                await self._waiter
            except (asyncio.CancelledError, asyncio.TimeoutError):
                # Drop the abandoned future so a later read() can wait again.
                self._waiter = None
                raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop the oldest message, or raise if the buffer is empty.

        After popping, the running size is reduced by ``data.size`` and the
        protocol is asked to resume reading if it was paused and the total is
        now below the threshold.  With an empty buffer the stored exception is
        raised if there is one, otherwise ``EofStream``.
        """
        if self._buffer:
            data = self._get_buffer()
            size = data.size
            self._size -= size
            if self._size < self._limit and self._protocol._reading_paused:
                self._protocol.resume_reading()
            return data
        if self._exception is not None:
            raise self._exception
        raise EofStream
+= payload if self._max_msg_size and len(self._partial) >= self._max_msg_size: raise WebSocketError( WSCloseCode.MESSAGE_TOO_BIG, f"Message size {len(self._partial)} " f"exceeds limit {self._max_msg_size}", ) return has_partial = bool(self._partial) if opcode == OP_CODE_CONTINUATION: opcode = self._opcode self._opcode = OP_CODE_NOT_SET # previous frame was non finished # we should get continuation opcode elif has_partial: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, "The opcode in non-fin frame is expected " f"to be zero, got {opcode!r}", ) assembled_payload: bytes | bytearray if has_partial: assembled_payload = self._partial + payload self._partial.clear() else: assembled_payload = payload if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: raise WebSocketError( WSCloseCode.MESSAGE_TOO_BIG, f"Message size {len(assembled_payload)} " f"exceeds limit {self._max_msg_size}", ) # Decompress process must to be done after all packets # received. if compressed: if not self._decompressobj: self._decompressobj = ZLibDecompressor(suppress_deflate_header=True) # XXX: It's possible that the zlib backend (isal is known to # do this, maybe others too?) will return max_length bytes, # but internally buffer more data such that the payload is # >max_length, so we return one extra byte and if we're able # to do that, then the message is too big. 
payload_merged = self._decompressobj.decompress_sync( assembled_payload + WS_DEFLATE_TRAILING, ( self._max_msg_size + 1 if self._max_msg_size else self._max_msg_size ), ) if self._max_msg_size and len(payload_merged) > self._max_msg_size: raise WebSocketError( WSCloseCode.MESSAGE_TOO_BIG, f"Decompressed message exceeds size limit {self._max_msg_size}", ) elif type(assembled_payload) is bytes: payload_merged = assembled_payload else: payload_merged = bytes(assembled_payload) size = len(payload_merged) if opcode == OP_CODE_TEXT: if self._decode_text: try: text = payload_merged.decode("utf-8") except UnicodeDecodeError as exc: raise WebSocketError( WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" ) from exc # XXX: The Text and Binary messages here can be a performance # bottleneck, so we use tuple.__new__ to improve performance. # This is not type safe, but many tests should fail in # test_client_ws_functional.py if this is wrong. msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT)) else: # Return raw bytes for TEXT messages when decode_text=False msg = TUPLE_NEW( WSMessageTextBytes, (payload_merged, size, "", WS_MSG_TYPE_TEXT) ) else: msg = TUPLE_NEW( WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY) ) self.queue.feed_data(msg) elif opcode == OP_CODE_CLOSE: payload_len = len(payload) if payload_len >= 2: close_code = UNPACK_CLOSE_CODE(payload[:2])[0] if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, f"Invalid close code: {close_code}", ) try: close_message = payload[2:].decode("utf-8") except UnicodeDecodeError as exc: raise WebSocketError( WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" ) from exc msg = WSMessageClose( data=close_code, size=payload_len, extra=close_message ) elif payload: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, f"Invalid close frame: {fin} {opcode} {payload!r}", ) else: msg = WSMessageClose(data=0, size=payload_len, extra="") 
def _feed_data(self, data: bytes) -> None:
    """Parse as many WebSocket frames as possible out of *data*.

    The parser is a small state machine (``READ_HEADER`` ->
    ``READ_PAYLOAD_LENGTH`` -> ``READ_PAYLOAD_MASK`` -> ``READ_PAYLOAD``)
    whose state lives on ``self`` so that a frame may be split across
    several calls.  Every completed frame is passed to ``_handle_frame``.
    Nothing is returned.

    Bytes that cannot be consumed yet because a header, extended length or
    mask is incomplete are kept in ``self._tail`` and prepended to the next
    call's data.  Partial payload bytes are kept in
    ``self._payload_fragments`` instead.

    Raises ``WebSocketError`` with ``WSCloseCode.PROTOCOL_ERROR`` for
    unexpected reserved bits, fragmented control frames and control frames
    whose length field exceeds 125.
    """
    if self._tail:
        # Re-attach the unparsed remainder of the previous call.
        data, self._tail = self._tail + data, b""

    start_pos: int = 0
    data_len = len(data)
    # Alias of ``data``; the pxd types it as ``const unsigned char *``.
    data_cstr = data

    while True:
        # read header
        if self._state == READ_HEADER:
            # The fixed part of the header is two bytes; wait for both.
            if data_len - start_pos < 2:
                break
            first_byte = data_cstr[start_pos]
            second_byte = data_cstr[start_pos + 1]
            start_pos += 2

            fin = (first_byte >> 7) & 1
            rsv1 = (first_byte >> 6) & 1
            rsv2 = (first_byte >> 5) & 1
            rsv3 = (first_byte >> 4) & 1
            opcode = first_byte & 0xF

            # frame-fin = %x0 ; more frames of this message follow
            #           / %x1 ; final frame of this message
            # frame-rsv1 = %x0 ;
            #    1 bit, MUST be 0 unless negotiated otherwise
            # frame-rsv2 = %x0 ;
            #    1 bit, MUST be 0 unless negotiated otherwise
            # frame-rsv3 = %x0 ;
            #    1 bit, MUST be 0 unless negotiated otherwise
            #
            # Remove rsv1 from this test for deflate development
            if rsv2 or rsv3 or (rsv1 and not self._compress):
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Received frame with non-zero reserved bits",
                )

            # Opcodes above 0x7 are control frames; they must carry FIN.
            if opcode > 0x7 and fin == 0:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Received fragmented control frame",
                )

            has_mask = (second_byte >> 7) & 1
            length = second_byte & 0x7F

            # Control frames MUST have a payload
            # length of 125 bytes or less
            if opcode > 0x7 and length > 125:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Control frame payload cannot be larger than 125 bytes",
                )

            # Set compress status if last package is FIN
            # OR set compress status if this is first fragment
            # Raise error if not first fragment with rsv1 = 0x1
            if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
                self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
            elif rsv1:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Received frame with non-zero reserved bits",
                )

            # Persist the header fields; the frame may span several calls.
            self._frame_fin = bool(fin)
            self._frame_opcode = opcode
            self._has_mask = bool(has_mask)
            self._payload_len_flag = length
            self._state = READ_PAYLOAD_LENGTH

        # read payload length
        if self._state == READ_PAYLOAD_LENGTH:
            len_flag = self._payload_len_flag
            if len_flag == 126:
                # Length is carried in the next two bytes (high byte first).
                if data_len - start_pos < 2:
                    break
                first_byte = data_cstr[start_pos]
                second_byte = data_cstr[start_pos + 1]
                start_pos += 2
                self._payload_bytes_to_read = first_byte << 8 | second_byte
            elif len_flag > 126:
                # Length is carried in the next eight bytes, decoded by
                # UNPACK_LEN3 at offset ``start_pos``.
                if data_len - start_pos < 8:
                    break
                self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
                start_pos += 8
            else:
                # Values below 126 are the payload length itself.
                self._payload_bytes_to_read = len_flag

            self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD

        # read payload mask
        if self._state == READ_PAYLOAD_MASK:
            if data_len - start_pos < 4:
                break
            self._frame_mask = data_cstr[start_pos : start_pos + 4]
            start_pos += 4
            self._state = READ_PAYLOAD

        if self._state == READ_PAYLOAD:
            # Consume as much of the payload as this chunk provides.
            chunk_len = data_len - start_pos
            if self._payload_bytes_to_read >= chunk_len:
                f_end_pos = data_len
                self._payload_bytes_to_read -= chunk_len
            else:
                f_end_pos = start_pos + self._payload_bytes_to_read
                self._payload_bytes_to_read = 0

            # Number of payload bytes of this frame buffered by earlier
            # calls; non-zero means fragments are waiting to be joined.
            had_fragments = self._frame_payload_len
            self._frame_payload_len += f_end_pos - start_pos
            f_start_pos = start_pos
            start_pos = f_end_pos

            if self._payload_bytes_to_read != 0:
                # If we don't have a complete frame, we need to save the
                # data for the next call to feed_data.
                self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                break

            payload: bytes | bytearray
            if had_fragments:
                # We have to join the payload fragments to get the payload
                self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                if self._has_mask:
                    assert self._frame_mask is not None
                    payload_bytearray = bytearray(b"".join(self._payload_fragments))
                    # websocket_mask is applied to the bytearray in place
                    # (its return value is not used here).
                    websocket_mask(self._frame_mask, payload_bytearray)
                    payload = payload_bytearray
                else:
                    payload = b"".join(self._payload_fragments)
                self._payload_fragments.clear()
            elif self._has_mask:
                assert self._frame_mask is not None
                payload_bytearray = data_cstr[f_start_pos:f_end_pos]  # type: ignore[assignment]
                if type(payload_bytearray) is not bytearray:  # pragma: no branch
                    # Cython will do the conversion for us
                    # but we need to do it for Python and we
                    # will always get here in Python
                    payload_bytearray = bytearray(payload_bytearray)
                websocket_mask(self._frame_mask, payload_bytearray)
                payload = payload_bytearray
            else:
                payload = data_cstr[f_start_pos:f_end_pos]

            self._handle_frame(
                self._frame_fin, self._frame_opcode, payload, self._compressed
            )
            # Reset per-frame bookkeeping and look for the next header.
            self._frame_payload_len = 0
            self._state = READ_HEADER

    # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
    self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
""" def __init__( self, protocol: BaseProtocol, transport: asyncio.Transport, *, use_mask: bool = False, limit: int = DEFAULT_LIMIT, random: random.Random = random.Random(), compress: int = 0, notakeover: bool = False, ) -> None: """Initialize a WebSocket writer.""" self.protocol = protocol self.transport = transport self.use_mask = use_mask self.get_random_bits = partial(random.getrandbits, 32) self.compress = compress self.notakeover = notakeover self._closing = False self._limit = limit self._output_size = 0 self._compressobj: ZLibCompressor | None = None self._send_lock = asyncio.Lock() self._background_tasks: set[asyncio.Task[None]] = set() async def send_frame( self, message: bytes, opcode: int, compress: int | None = None ) -> None: """Send a frame over the websocket with message as its payload.""" if self._closing and not (opcode & WSMsgType.CLOSE): raise ClientConnectionResetError("Cannot write to closing transport") if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE: # Non-compressed frames don't need lock or shield self._write_websocket_frame(message, opcode, 0) elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE: # Small compressed payloads - compress synchronously in event loop # We need the lock even though sync compression has no await points. # This prevents small frames from interleaving with large frames that # compress in the executor, avoiding compressor state corruption. async with self._send_lock: self._send_compressed_frame_sync(message, opcode, compress) else: # Large compressed frames need shield to prevent corruption # For large compressed frames, the entire compress+send # operation must be atomic. If cancelled after compression but # before send, the compressor state would be advanced but data # not sent, corrupting subsequent frames. # Create a task to shield from cancellation # The lock is acquired inside the shielded task so the entire # operation (lock + compress + send) completes atomically. 
def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
    """
    Serialize one frame and hand it to the transport.

    Builds the frame header for ``message``, applies client-side masking when
    ``self.use_mask`` is set, writes the result to ``self.transport`` and adds
    the number of bytes written to ``self._output_size``.  Compression and
    flow control are left to the caller.

    Raises ``ClientConnectionResetError`` if the transport is already closing.
    """
    payload_len = len(message)
    masked = self.use_mask

    # First header byte: FIN bit, the caller-supplied RSV bits and the opcode.
    first_byte = 0x80 | rsv | opcode
    # The top bit of the second header byte announces a masked payload.
    mask_flag = 0x80 if masked else 0

    # The length encoding (and therefore the header size) depends on how
    # large the payload is.
    if payload_len >= 65536:
        header = PACK_LEN3(first_byte, 127 | mask_flag, payload_len)
        header_len = 10
    elif payload_len >= 126:
        header = PACK_LEN2(first_byte, 126 | mask_flag, payload_len)
        header_len = 4
    else:
        header = PACK_LEN1(first_byte, payload_len | mask_flag)
        header_len = 2

    if self.transport.is_closing():
        raise ClientConnectionResetError("Cannot write to closing transport")

    # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
    # A random 32-bit mask is XOR-ed over the payload when aiohttp acts as
    # a client; servers send their frames unmasked.
    if masked:
        mask = PACK_RANDBITS(self.get_random_bits())
        masked_payload = bytearray(message)
        websocket_mask(mask, masked_payload)
        self.transport.write(header + mask + masked_payload)
        self._output_size += MASK_LEN
    elif payload_len <= MSG_SIZE:
        self.transport.write(header + message)
    else:
        # Payloads above MSG_SIZE are written separately from the header
        # instead of being concatenated with it.
        self.transport.write(header)
        self.transport.write(message)

    self._output_size += header_len + payload_len
async def _send_compressed_frame_async_locked(
    self, message: bytes, opcode: int, compress: int | None
) -> None:
    """
    Compress a large payload asynchronously and send it, under the send lock.

    ``self._send_lock`` is held from before the compressor is obtained until
    the frame has been written, so concurrent sends cannot interleave with
    this one and corrupt the compressor state.

    MUST be run shielded from cancellation: a cancellation between the
    compression step and the write would advance the compressor without the
    data being sent, corrupting subsequent frames.
    """
    async with self._send_lock:
        compressor = self._get_compressor(compress)
        deflated = await compressor.compress(message)
        if self.notakeover:
            flush_mode = ZLibBackend.Z_FULL_FLUSH
        else:
            flush_mode = ZLibBackend.Z_SYNC_FLUSH
        deflated = deflated + compressor.flush(flush_mode)
        # RSV bits mark frames that use an extension; RSV1 (0x40) is set
        # for compressed frames.
        # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
        # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
        self._write_websocket_frame(
            deflated.removesuffix(WS_DEFLATE_TRAILING), opcode, 0x40
        )
""" @property def frozen(self) -> bool: return self._frozen def freeze(self) -> None: """Freeze router.""" self._frozen = True @abstractmethod async def resolve(self, request: Request) -> "AbstractMatchInfo": """Return MATCH_INFO for given request""" class AbstractMatchInfo(ABC): __slots__ = () @property # pragma: no branch @abstractmethod def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]: """Execute matched request handler""" @property @abstractmethod def expect_handler( self, ) -> Callable[[Request], Awaitable[StreamResponse | None]]: """Expect handler for 100-continue processing""" @property # pragma: no branch @abstractmethod def http_exception(self) -> HTTPException | None: """HTTPException instance raised on router's resolving, or None""" @abstractmethod # pragma: no branch def get_info(self) -> dict[str, Any]: """Return a dict with additional info useful for introspection""" @property # pragma: no branch @abstractmethod def apps(self) -> tuple[Application, ...]: """Stack of nested applications. Top level application is left-most element. """ @abstractmethod def add_app(self, app: Application) -> None: """Add application to the nested apps stack.""" @abstractmethod def freeze(self) -> None: """Freeze the match info. The method is called after route resolution. After the call .add_app() is forbidden. """ class AbstractView(ABC): """Abstract class based view.""" def __init__(self, request: Request) -> None: self._request = request @property def request(self) -> Request: """Request instance.""" return self._request @abstractmethod def __await__(self) -> Generator[None, None, StreamResponse]: """Execute the view handler.""" class ResolveResult(TypedDict): """Resolve result. This is the result returned from an AbstractResolver's resolve method. :param hostname: The hostname that was provided. :param host: The IP address that was resolved. :param port: The port that was resolved. :param family: The address family that was resolved. 
class AbstractCookieJar(Sized, Iterable[Morsel[str]]):
    """Interface that every cookie jar implementation has to provide."""

    @property
    @abstractmethod
    def quote_cookie(self) -> bool:
        """Tell whether cookie values should be quoted."""

    @abstractmethod
    def clear(self, predicate: ClearCookiePredicate | None = None) -> None:
        """Remove cookies; with no predicate given, remove all of them."""

    @abstractmethod
    def clear_domain(self, domain: str) -> None:
        """Remove every cookie stored for *domain* and its subdomains."""

    @abstractmethod
    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Store *cookies* in the jar."""

    def update_cookies_from_headers(
        self, headers: Sequence[str], response_url: URL
    ) -> None:
        """Parse raw Set-Cookie header values and store the result.

        Nothing happens when *headers* is empty or when parsing yields
        nothing to store.
        """
        if not headers:
            return
        parsed = parse_set_cookie_headers(headers)
        if parsed:
            self.update_cookies(parsed, response_url)

    @abstractmethod
    def filter_cookies(self, request_url: URL) -> BaseCookie[str]:
        """Return the cookies of the jar selected by their attributes."""
str = "deflate", strategy: int | None = None ) -> None: """Enable HTTP body compression""" @abstractmethod def enable_chunking(self) -> None: """Enable HTTP chunked mode""" @abstractmethod async def write_headers(self, status_line: str, headers: CIMultiDict[str]) -> None: """Write HTTP headers""" def send_headers(self) -> None: """Force sending buffered headers if not already sent. Required only if write_headers() buffers headers instead of sending immediately. For backwards compatibility, this method does nothing by default. """ class AbstractAccessLogger(ABC): """Abstract writer to access log.""" __slots__ = ("logger", "log_format") def __init__(self, logger: logging.Logger, log_format: str) -> None: self.logger = logger self.log_format = log_format @abstractmethod def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None: """Emit log to logger.""" @property def enabled(self) -> bool: """Check if logger is enabled.""" return True class AbstractAsyncAccessLogger(ABC): """Abstract asynchronous writer to access log.""" __slots__ = () @abstractmethod async def log( self, request: BaseRequest, response: StreamResponse, request_start: float ) -> None: """Emit log to logger.""" @property def enabled(self) -> bool: """Check if logger is enabled.""" return True ================================================ FILE: aiohttp/base_protocol.py ================================================ import asyncio from typing import cast from .client_exceptions import ClientConnectionResetError from .helpers import set_exception from .tcp_helpers import tcp_nodelay class BaseProtocol(asyncio.Protocol): __slots__ = ( "_loop", "_paused", "_drain_waiter", "_connection_lost", "_reading_paused", "transport", ) def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop: asyncio.AbstractEventLoop = loop self._paused = False self._drain_waiter: asyncio.Future[None] | None = None self._reading_paused = False self.transport: asyncio.Transport | None = None 
# NOTE(review): the definitions down to the "FILE:" separator are methods of a
# class whose header lies before this span (the client-side code below imports
# nothing from here; they read/write self.transport, self._paused,
# self._reading_paused and self._drain_waiter).
@property
def connected(self) -> bool:
    """Return True if the connection is open."""
    return self.transport is not None


# Expose the write-side flow-control flag toggled by pause_writing() /
# resume_writing().
@property
def writing_paused(self) -> bool:
    return self._paused


# Mark the write side as paused; calling it while already paused trips the
# assertion.
def pause_writing(self) -> None:
    assert not self._paused
    self._paused = True


# Clear the paused flag and release whoever is blocked in _drain_helper().
def resume_writing(self) -> None:
    assert self._paused
    self._paused = False

    waiter = self._drain_waiter
    if waiter is not None:
        # Detach the waiter first so a later pause creates a fresh future.
        self._drain_waiter = None
        if not waiter.done():
            waiter.set_result(None)


# Ask the transport to stop delivering data. Transport errors listed in the
# except clause are deliberately ignored; the flag is set either way.
def pause_reading(self) -> None:
    if not self._reading_paused and self.transport is not None:
        try:
            self.transport.pause_reading()
        except (AttributeError, NotImplementedError, RuntimeError):
            pass
        self._reading_paused = True


# Mirror image of pause_reading(): resume delivery and clear the flag.
def resume_reading(self) -> None:
    if self._reading_paused and self.transport is not None:
        try:
            self.transport.resume_reading()
        except (AttributeError, NotImplementedError, RuntimeError):
            pass
        self._reading_paused = False


# Remember the transport (narrowed to asyncio.Transport for typing) after
# calling tcp_nodelay(tr, True) on it.
def connection_made(self, transport: asyncio.BaseTransport) -> None:
    tr = cast(asyncio.Transport, transport)
    tcp_nodelay(tr, True)
    self.transport = tr


# Drop the transport reference and, if a writer is parked on the drain
# waiter, finish that waiter: with a plain result for a clean close, or with
# ConnectionError("Connection lost") passed to set_exception() together with
# the original exc otherwise.
def connection_lost(self, exc: BaseException | None) -> None:
    # Wake up the writer if currently paused.
    self.transport = None
    if not self._paused:
        return
    waiter = self._drain_waiter
    if waiter is None:
        return
    self._drain_waiter = None
    if waiter.done():
        return
    if exc is None:
        waiter.set_result(None)
    else:
        set_exception(
            waiter,
            ConnectionError("Connection lost"),
            exc,
        )


# Wait until writing is resumed.
# Raises ClientConnectionResetError when there is no transport. Returns
# immediately when writing is not paused. The waiter future is created lazily
# and awaited through asyncio.shield(), so cancelling the awaiting task does
# not cancel the stored waiter itself.
async def _drain_helper(self) -> None:
    if self.transport is None:
        raise ClientConnectionResetError("Connection lost")
    if not self._paused:
        return
    waiter = self._drain_waiter
    if waiter is None:
        waiter = self._loop.create_future()
        self._drain_waiter = waiter
    await asyncio.shield(waiter)


================================================
FILE: aiohttp/client.py
================================================
"""HTTP Client for asyncio."""

import asyncio
import base64
import dataclasses
import hashlib
import json
import os
import sys
import traceback
import warnings
from collections.abc import (
    Awaitable,
    Callable,
    Collection,
    Coroutine,
    Generator,
    Iterable,
    Sequence,
)
from contextlib import suppress
from types import TracebackType
from typing import (
    TYPE_CHECKING,
    Any,
    Final,
    Generic,
    Literal,
    TypedDict,
    TypeVar,
    final,
    overload,
)

from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
from yarl import URL, Query

from . import hdrs, http, payload
from ._websocket.reader import WebSocketDataQueue
from .abc import AbstractCookieJar
from .client_exceptions import (
    ClientConnectionError,
    ClientConnectionResetError,
    ClientConnectorCertificateError,
    ClientConnectorDNSError,
    ClientConnectorError,
    ClientConnectorSSLError,
    ClientError,
    ClientHttpProxyError,
    ClientOSError,
    ClientPayloadError,
    ClientProxyConnectionError,
    ClientResponseError,
    ClientSSLError,
    ConnectionTimeoutError,
    ContentTypeError,
    InvalidURL,
    InvalidUrlClientError,
    InvalidUrlRedirectClientError,
    NonHttpUrlClientError,
    NonHttpUrlRedirectClientError,
    RedirectClientError,
    ServerConnectionError,
    ServerDisconnectedError,
    ServerFingerprintMismatch,
    ServerTimeoutError,
    SocketTimeoutError,
    TooManyRedirects,
    WSMessageTypeError,
    WSServerHandshakeError,
)
from .client_middlewares import ClientMiddlewareType, build_client_middlewares
from .client_reqrep import (
    SSL_ALLOWED_TYPES,
    ClientRequest,
    ClientResponse,
    Fingerprint,
    RequestInfo,
)
from .client_ws import (
    DEFAULT_WS_CLIENT_TIMEOUT,
    ClientWebSocketResponse,
    ClientWSTimeout,
)
from .connector import (
    HTTP_AND_EMPTY_SCHEMA_SET,
    BaseConnector,
    NamedPipeConnector,
    TCPConnector,
    UnixConnector,
)
from .cookiejar import CookieJar
from .helpers import (
    _SENTINEL,
    EMPTY_BODY_METHODS,
    BasicAuth,
    TimeoutHandle,
    basicauth_from_netrc,
    frozen_dataclass_decorator,
    get_env_proxy_for_url,
    netrc_from_env,
    sentinel,
    strip_auth_from_url,
)
from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter
from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse
from .tracing import Trace, TraceConfig
from .typedefs import (
    JSONBytesEncoder,
    JSONEncoder,
    LooseCookies,
    LooseHeaders,
    StrOrURL,
)

# Public names of this module, grouped by the module they come from.
__all__ = (
    # client_exceptions
    "ClientConnectionError",
    "ClientConnectionResetError",
    "ClientConnectorCertificateError",
    "ClientConnectorDNSError",
    "ClientConnectorError",
    "ClientConnectorSSLError",
    "ClientError",
    "ClientHttpProxyError",
    "ClientOSError",
    "ClientPayloadError",
    "ClientProxyConnectionError",
    "ClientResponseError",
    "ClientSSLError",
    "ConnectionTimeoutError",
    "ContentTypeError",
    "InvalidURL",
    "InvalidUrlClientError",
    "RedirectClientError",
    "NonHttpUrlClientError",
    "InvalidUrlRedirectClientError",
    "NonHttpUrlRedirectClientError",
    "ServerConnectionError",
    "ServerDisconnectedError",
    "ServerFingerprintMismatch",
    "ServerTimeoutError",
    "SocketTimeoutError",
    "TooManyRedirects",
    "WSServerHandshakeError",
    # client_reqrep
    "ClientRequest",
    "ClientResponse",
    "Fingerprint",
    "RequestInfo",
    # connector
    "BaseConnector",
    "TCPConnector",
    "UnixConnector",
    "NamedPipeConnector",
    # client_ws
    "ClientWebSocketResponse",
    # client
    "ClientSession",
    "ClientTimeout",
    "ClientWSTimeout",
    "request",
    "WSMessageTypeError",
)


# SSLContext is only needed for annotations; at runtime the name is bound to
# None so that the ssl module is not imported here.
if TYPE_CHECKING:
    from ssl import SSLContext
else:
    SSLContext = None

if sys.version_info >= (3, 11) and TYPE_CHECKING:
    from typing import Unpack


# Typing-only description of the keyword arguments accepted by the request
# helpers; the keys mirror the keyword-only parameters of
# ClientSession._request(). total=False makes every key optional.
class _RequestOptions(TypedDict, total=False):
    params: Query
    data: Any
    json: Any
    cookies: LooseCookies | None
    headers: LooseHeaders | None
    skip_auto_headers: Iterable[str] | None
    auth: BasicAuth | None
    allow_redirects: bool
    max_redirects: int
    compress: str | bool
    chunked: bool | None
    expect100: bool
    raise_for_status: None | bool | Callable[[ClientResponse], Awaitable[None]]
    read_until_eof: bool
    proxy: StrOrURL | None
    proxy_auth: BasicAuth | None
    timeout: "ClientTimeout | _SENTINEL | None"
    ssl: SSLContext | bool | Fingerprint
    server_hostname: str | None
    proxy_headers: LooseHeaders | None
    trace_request_ctx: object
    read_bufsize: int | None
    auto_decompress: bool | None
    max_line_size: int | None
    max_field_size: int | None
    max_headers: int | None
    middlewares: Sequence[ClientMiddlewareType] | None


# Typing-only description of the keyword arguments used by the ws_connect()
# overloads below (everything except decode_text, which the overloads declare
# explicitly).
class _WSConnectOptions(TypedDict, total=False):
    method: str
    protocols: Collection[str]
    timeout: "ClientWSTimeout | _SENTINEL"
    receive_timeout: float | None
    autoclose: bool
    autoping: bool
    heartbeat: float | None
    auth: BasicAuth | None
    origin: str | None
    params: Query
    headers: LooseHeaders | None
    proxy: StrOrURL | None
    proxy_auth: BasicAuth | None
    ssl: SSLContext | bool | Fingerprint
    server_hostname: str | None
    proxy_headers: LooseHeaders | None
    compress: int
    max_msg_size: int


# Timeout settings for a session or a single request. Every field defaults to
# None except ceil_threshold. In this file `total` drives the TimeoutHandle in
# _request(), `sock_read` becomes the protocol read timeout, and the whole
# object is handed to the connector's connect() call.
@frozen_dataclass_decorator
class ClientTimeout:
    total: float | None = None
    connect: float | None = None
    sock_read: float | None = None
    sock_connect: float | None = None
    ceil_threshold: float = 5

    # pool_queue_timeout: Optional[float] = None
    # dns_resolution_timeout: Optional[float] = None
    # socket_connect_timeout: Optional[float] = None
    # connection_acquiring_timeout: Optional[float] = None
    # new_connection_timeout: Optional[float] = None
    # http_header_timeout: Optional[float] = None
    # response_body_timeout: Optional[float] = None

    # to create a timeout specific for a single request, either
    # - create a completely new one to overwrite the default
    # - or use https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
    #   to overwrite the defaults

    # Reject total == 0 explicitly: the error message states that 0 no longer
    # means "disabled" and that None must be used instead.
    def __post_init__(self) -> None:
        if self.total is not None and self.total == 0:
            raise ValueError(
                "total timeout must be a positive number or None to disable, "
                "got 0. Using 0 to disable timeouts is no longer supported, "
                "use None instead."
            )


# Default timeouts: 5 minutes for the whole request (total) and 30 seconds
# for sock_connect.
DEFAULT_TIMEOUT: Final[ClientTimeout] = ClientTimeout(total=5 * 60, sock_connect=30)

# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})

_RetType_co = TypeVar(
    "_RetType_co",
    bound="ClientResponse | ClientWebSocketResponse[bool]",
    covariant=True,
)
# Signature of the fallback charset resolver: (response, body bytes) -> name.
_CharsetResolver = Callable[[ClientResponse, bytes], str]


@final
class ClientSession:
    """First-class interface for making HTTP requests."""

    __slots__ = (
        "_base_url",
        "_base_url_origin",
        "_source_traceback",
        "_connector",
        "_loop",
        "_cookie_jar",
        "_connector_owner",
        "_default_auth",
        "_version",
        "_json_serialize",
        "_json_serialize_bytes",
        "_requote_redirect_url",
        "_timeout",
        "_raise_for_status",
        "_auto_decompress",
        "_trust_env",
        "_default_headers",
        "_skip_auto_headers",
        "_request_class",
        "_response_class",
        "_ws_response_class",
        "_trace_configs",
        "_read_bufsize",
        "_max_line_size",
        "_max_field_size",
        "_max_headers",
        "_resolve_charset",
        "_default_proxy",
        "_default_proxy_auth",
        "_retry_connection",
        "_middlewares",
    )

    # Store the session-wide defaults.
    #
    # Must run inside a running event loop (asyncio.get_running_loop()).
    # Raises:
    #   ValueError   - base_url path lacks a trailing "/", or timeout is not a
    #                  ClientTimeout (sentinel/None select DEFAULT_TIMEOUT).
    #   RuntimeError - the connector belongs to a different event loop.
    # Passing ssl_shutdown_timeout emits a DeprecationWarning; the value is
    # only forwarded when the session creates its own TCPConnector.
    def __init__(
        self,
        base_url: StrOrURL | None = None,
        *,
        connector: BaseConnector | None = None,
        cookies: LooseCookies | None = None,
        headers: LooseHeaders | None = None,
        proxy: StrOrURL | None = None,
        proxy_auth: BasicAuth | None = None,
        skip_auto_headers: Iterable[str] | None = None,
        auth: BasicAuth | None = None,
        json_serialize: JSONEncoder = json.dumps,
        json_serialize_bytes: JSONBytesEncoder | None = None,
        request_class: type[ClientRequest] = ClientRequest,
        response_class: type[ClientResponse] = ClientResponse,
        ws_response_class: type[ClientWebSocketResponse] = ClientWebSocketResponse,
        version: HttpVersion = http.HttpVersion11,
        cookie_jar: AbstractCookieJar | None = None,
        connector_owner: bool = True,
        raise_for_status: bool | Callable[[ClientResponse], Awaitable[None]] = False,
        timeout: _SENTINEL | ClientTimeout | None = sentinel,
        auto_decompress: bool = True,
        trust_env: bool = False,
        requote_redirect_url: bool = True,
        trace_configs: list[TraceConfig[object]] | None = None,
        read_bufsize: int = 2**16,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
        max_headers: int = 128,
        fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8",
        middlewares: Sequence[ClientMiddlewareType] = (),
        ssl_shutdown_timeout: _SENTINEL | None | float = sentinel,
    ) -> None:
        # We initialise _connector to None immediately, as it's referenced in __del__()
        # and could cause issues if an exception occurs during initialisation.
        self._connector: BaseConnector | None = None
        if base_url is None or isinstance(base_url, URL):
            self._base_url: URL | None = base_url
            self._base_url_origin = None if base_url is None else base_url.origin()
        else:
            self._base_url = URL(base_url)
            self._base_url_origin = self._base_url.origin()
            assert self._base_url.absolute, "Only absolute URLs are supported"
        if self._base_url is not None and not self._base_url.path.endswith("/"):
            raise ValueError("base_url must have a trailing '/'")

        loop = asyncio.get_running_loop()

        if timeout is sentinel or timeout is None:
            timeout = DEFAULT_TIMEOUT
        if not isinstance(timeout, ClientTimeout):
            raise ValueError(
                f"timeout parameter cannot be of {type(timeout)} type, "
                "please use 'timeout=ClientTimeout(...)'",
            )
        self._timeout = timeout

        if ssl_shutdown_timeout is not sentinel:
            warnings.warn(
                "The ssl_shutdown_timeout parameter is deprecated and will be removed in aiohttp 4.0",
                DeprecationWarning,
                stacklevel=2,
            )

        if connector is None:
            connector = TCPConnector(ssl_shutdown_timeout=ssl_shutdown_timeout)
        # Initialize these three attrs before raising any exception,
        # they are used in __del__
        self._connector = connector
        self._loop = loop
        if loop.get_debug():
            # In debug mode remember where the session was created so that
            # __del__ can report it for unclosed sessions.
            self._source_traceback: traceback.StackSummary | None = (
                traceback.extract_stack(sys._getframe(1))
            )
        else:
            self._source_traceback = None

        if connector._loop is not loop:
            raise RuntimeError("Session and connector have to use same event loop")

        if cookie_jar is None:
            cookie_jar = CookieJar()
        self._cookie_jar = cookie_jar

        if cookies:
            self._cookie_jar.update_cookies(cookies)

        self._connector_owner = connector_owner
        self._default_auth = auth
        self._version = version
        self._json_serialize = json_serialize
        self._json_serialize_bytes = json_serialize_bytes
        self._raise_for_status = raise_for_status
        self._auto_decompress = auto_decompress
        self._trust_env = trust_env
        self._requote_redirect_url = requote_redirect_url
        self._read_bufsize = read_bufsize
        self._max_line_size = max_line_size
        self._max_field_size = max_field_size
        self._max_headers = max_headers

        # Normalise the default headers into a case-insensitive multidict
        if headers:
            real_headers: CIMultiDict[str] = CIMultiDict(headers)
        else:
            real_headers = CIMultiDict()
        self._default_headers: CIMultiDict[str] = real_headers
        if skip_auto_headers is not None:
            self._skip_auto_headers = frozenset(istr(i) for i in skip_auto_headers)
        else:
            self._skip_auto_headers = frozenset()

        self._request_class = request_class
        self._response_class = response_class
        self._ws_response_class = ws_response_class

        # Trace configs are frozen here, when the session takes them over.
        self._trace_configs = trace_configs or []
        for trace_config in self._trace_configs:
            trace_config.freeze()

        self._resolve_charset = fallback_charset_resolver

        self._default_proxy = proxy
        self._default_proxy_auth = proxy_auth
        # Gates the single retry performed in _request() for idempotent
        # methods after ClientOSError / ServerDisconnectedError.
        self._retry_connection: bool = True
        self._middlewares = middlewares

    # Subclassing is rejected at class-creation time.
    def __init_subclass__(cls: type["ClientSession"]) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ClientSession is forbidden"
        )

    # Warn about a session that was never closed and report it to the loop's
    # exception handler. `warnings` is bound as a default argument so the
    # module object is captured when the function is defined.
    def __del__(self, _warnings: Any = warnings) -> None:
        if not self.closed:
            _warnings.warn(
                f"Unclosed client session {self!r}",
                ResourceWarning,
                source=self,
            )
            context = {"client_session": self, "message": "Unclosed client session"}
            if self._source_traceback is not None:
                context["source_traceback"] = self._source_traceback
            self._loop.call_exception_handler(context)

    # The first branch exists only for type checkers (precise keyword typing
    # via Unpack); the else branch is the runtime implementation.
    if sys.version_info >= (3, 11) and TYPE_CHECKING:

        def request(
            self,
            method: str,
            url: StrOrURL,
            **kwargs: Unpack[_RequestOptions],
        ) -> "_RequestContextManager": ...

    else:

        def request(
            self, method: str, url: StrOrURL, **kwargs: Any
        ) -> "_RequestContextManager":
            """Perform HTTP request."""
            return _RequestContextManager(self._request(method, url, **kwargs))

    # Turn the caller-supplied value into a URL, joining non-absolute URLs
    # onto the session base_url when one is configured.
    def _build_url(self, str_or_url: StrOrURL) -> URL:
        url = URL(str_or_url)
        if self._base_url and not url.absolute:
            return self._base_url.join(url)
        return url

    # Core request implementation behind request()/get()/post()/...
    #
    # Outline:
    #   1. validate arguments and merge them with the session defaults;
    #   2. start the cumulative timeout and emit the request-start traces;
    #   3. loop: build a request object, send it (through middlewares when
    #      configured), then either follow a redirect or leave the loop;
    #   4. close the request body, apply raise_for_status, tie the timeout
    #      handle to the connection and emit the request-end traces.
    # Any exception (BaseException) closes the timer, cancels the timeout
    # handle, emits the request-exception traces and is re-raised.
    async def _request(
        self,
        method: str,
        str_or_url: StrOrURL,
        *,
        params: Query = None,
        data: Any = None,
        json: Any = None,
        cookies: LooseCookies | None = None,
        headers: LooseHeaders | None = None,
        skip_auto_headers: Iterable[str] | None = None,
        auth: BasicAuth | None = None,
        allow_redirects: bool = True,
        max_redirects: int = 10,
        compress: str | bool = False,
        chunked: bool | None = None,
        expect100: bool = False,
        raise_for_status: (
            None | bool | Callable[[ClientResponse], Awaitable[None]]
        ) = None,
        read_until_eof: bool = True,
        proxy: StrOrURL | None = None,
        proxy_auth: BasicAuth | None = None,
        timeout: ClientTimeout | _SENTINEL | None = sentinel,
        ssl: SSLContext | bool | Fingerprint = True,
        server_hostname: str | None = None,
        proxy_headers: LooseHeaders | None = None,
        trace_request_ctx: object = None,
        read_bufsize: int | None = None,
        auto_decompress: bool | None = None,
        max_line_size: int | None = None,
        max_field_size: int | None = None,
        max_headers: int | None = None,
        middlewares: Sequence[ClientMiddlewareType] | None = None,
    ) -> ClientResponse:
        # NOTE: timeout clamps existing connect and read timeouts.  We cannot
        # set the default to None because we need to detect if the user wants
        # to use the existing timeouts by setting timeout to None.

        if self.closed:
            raise RuntimeError("Session is closed")

        if not isinstance(ssl, SSL_ALLOWED_TYPES):
            raise TypeError(
                "ssl should be SSLContext, Fingerprint, or bool, "
                f"got {ssl!r} instead."
            )

        # `json` is shorthand for a JSON payload; the bytes serializer is
        # preferred when the session has one.
        if data is not None and json is not None:
            raise ValueError(
                "data and json parameters can not be used at the same time"
            )
        elif json is not None:
            if self._json_serialize_bytes is not None:
                data = payload.JsonBytesPayload(json, dumps=self._json_serialize_bytes)
            else:
                data = payload.JsonPayload(json, dumps=self._json_serialize)

        redirects = 0
        history: list[ClientResponse] = []
        version = self._version
        params = params or {}

        # Merge with default headers and transform to CIMultiDict
        headers = self._prepare_headers(headers)

        try:
            url = self._build_url(str_or_url)
        except ValueError as e:
            raise InvalidUrlClientError(str_or_url) from e

        assert self._connector is not None
        if url.scheme not in self._connector.allowed_protocol_schema_set:
            raise NonHttpUrlClientError(url)

        # Per-request skip list is unioned with the session-level one.
        skip_headers: Iterable[istr] | None
        if skip_auto_headers is not None:
            skip_headers = {
                istr(i) for i in skip_auto_headers
            } | self._skip_auto_headers
        elif self._skip_auto_headers:
            skip_headers = self._skip_auto_headers
        else:
            skip_headers = None

        if proxy is None:
            proxy = self._default_proxy
        if proxy_auth is None:
            proxy_auth = self._default_proxy_auth

        if proxy is None:
            proxy_headers = None
        else:
            proxy_headers = self._prepare_headers(proxy_headers)
            try:
                proxy = URL(proxy)
            except ValueError as e:
                raise InvalidURL(proxy) from e

        if timeout is sentinel or timeout is None:
            real_timeout: ClientTimeout = self._timeout
        else:
            real_timeout = timeout
        # timeout is cumulative for all request operations
        # (request, redirects, responses, data consuming)
        tm = TimeoutHandle(
            self._loop, real_timeout.total, ceil_threshold=real_timeout.ceil_threshold
        )
        handle = tm.start()

        # Per-request parser limits fall back to the session defaults.
        if read_bufsize is None:
            read_bufsize = self._read_bufsize

        if auto_decompress is None:
            auto_decompress = self._auto_decompress

        if max_line_size is None:
            max_line_size = self._max_line_size

        if max_field_size is None:
            max_field_size = self._max_field_size

        if max_headers is None:
            max_headers = self._max_headers

        traces = [
            Trace(
                self,
                trace_config,
                trace_config.trace_config_ctx(trace_request_ctx=trace_request_ctx),
            )
            for trace_config in self._trace_configs
        ]

        for trace in traces:
            await trace.send_request_start(method, url.update_query(params), headers)

        timer = tm.timer()
        try:
            with timer:
                # https://www.rfc-editor.org/rfc/rfc9112.html#name-retrying-requests
                retry_persistent_connection = (
                    self._retry_connection and method in IDEMPOTENT_METHODS
                )
                while True:
                    url, auth_from_url = strip_auth_from_url(url)
                    if not url.raw_host:
                        # NOTE: Bail early, otherwise, causes `InvalidURL` through
                        # NOTE: `self._request_class()` below.
                        err_exc_cls = (
                            InvalidUrlRedirectClientError
                            if redirects
                            else InvalidUrlClientError
                        )
                        raise err_exc_cls(url)
                    # If `auth` was passed for an already authenticated URL,
                    # disallow only if this is the initial URL; this is to avoid issues
                    # with sketchy redirects that are not the caller's responsibility
                    if not history and (auth and auth_from_url):
                        raise ValueError(
                            "Cannot combine AUTH argument with "
                            "credentials encoded in URL"
                        )

                    # Override the auth with the one from the URL only if we
                    # have no auth, or if we got an auth from a redirect URL
                    if auth is None or (history and auth_from_url is not None):
                        auth = auth_from_url

                    # Session default auth applies only when there is no
                    # base_url or the request stays on the base_url origin.
                    if (
                        auth is None
                        and self._default_auth
                        and (
                            not self._base_url or self._base_url_origin == url.origin()
                        )
                    ):
                        auth = self._default_auth

                    # Try netrc if auth is still None and trust_env is enabled.
                    if auth is None and self._trust_env and url.host is not None:
                        auth = await self._loop.run_in_executor(
                            None, self._get_netrc_auth, url.host
                        )

                    # It would be confusing if we support explicit
                    # Authorization header with auth argument
                    if auth is not None and hdrs.AUTHORIZATION in headers:
                        raise ValueError(
                            "Cannot combine AUTHORIZATION header "
                            "with AUTH argument or credentials "
                            "encoded in URL"
                        )

                    all_cookies = self._cookie_jar.filter_cookies(url)

                    # Per-request cookies go through a temporary jar so they
                    # are filtered for this URL, then loaded on top of the
                    # session cookies.
                    if cookies is not None:
                        tmp_cookie_jar = CookieJar(
                            quote_cookie=self._cookie_jar.quote_cookie
                        )
                        tmp_cookie_jar.update_cookies(cookies)
                        req_cookies = tmp_cookie_jar.filter_cookies(url)
                        if req_cookies:
                            all_cookies.load(req_cookies)

                    # Explicit/default proxy wins; otherwise, with trust_env,
                    # look the proxy up from the environment in a thread.
                    proxy_: URL | None = None
                    if proxy is not None:
                        proxy_ = URL(proxy)
                    elif self._trust_env:
                        with suppress(LookupError):
                            proxy_, proxy_auth = await asyncio.to_thread(
                                get_env_proxy_for_url, url
                            )

                    req = self._request_class(
                        method,
                        url,
                        params=params,
                        headers=headers,
                        skip_auto_headers=skip_headers,
                        data=data,
                        cookies=all_cookies,
                        auth=auth,
                        version=version,
                        compress=compress,
                        chunked=chunked,
                        expect100=expect100,
                        loop=self._loop,
                        response_class=self._response_class,
                        proxy=proxy_,
                        proxy_auth=proxy_auth,
                        timer=timer,
                        session=self,
                        ssl=ssl,
                        server_hostname=server_hostname,
                        proxy_headers=proxy_headers,
                        traces=traces,
                        trust_env=self.trust_env,
                    )

                    # Innermost handler: acquire a connection, configure the
                    # response parser, send the request and start the
                    # response. The connection is closed on any failure.
                    async def _connect_and_send_request(
                        req: ClientRequest,
                    ) -> ClientResponse:
                        # connection timeout
                        assert self._connector is not None
                        try:
                            conn = await self._connector.connect(
                                req, traces=traces, timeout=real_timeout
                            )
                        except asyncio.TimeoutError as exc:
                            raise ConnectionTimeoutError(
                                f"Connection timeout to host {req.url}"
                            ) from exc

                        assert conn.protocol is not None
                        conn.protocol.set_response_params(
                            timer=timer,
                            skip_payload=req.method in EMPTY_BODY_METHODS,
                            read_until_eof=read_until_eof,
                            auto_decompress=auto_decompress,
                            read_timeout=real_timeout.sock_read,
                            read_bufsize=read_bufsize,
                            timeout_ceil_threshold=self._connector._timeout_ceil_threshold,
                            max_line_size=max_line_size,
                            max_field_size=max_field_size,
                            max_headers=max_headers,
                        )
                        try:
                            resp = await req._send(conn)
                            try:
                                await resp.start(conn)
                            except BaseException:
                                resp.close()
                                raise
                        except BaseException:
                            conn.close()
                            raise
                        return resp

                    # Apply middleware (if any) - per-request middleware overrides session middleware
                    effective_middlewares = (
                        self._middlewares if middlewares is None else middlewares
                    )

                    if effective_middlewares:
                        handler = build_client_middlewares(
                            _connect_and_send_request, effective_middlewares
                        )
                    else:
                        handler = _connect_and_send_request

                    try:
                        resp = await handler(req)
                    # Client connector errors should not be retried
                    except (
                        ConnectionTimeoutError,
                        ClientConnectorError,
                        ClientConnectorCertificateError,
                        ClientConnectorSSLError,
                    ):
                        raise
                    except (ClientOSError, ServerDisconnectedError):
                        # Retry at most once, and only for idempotent methods.
                        if retry_persistent_connection:
                            retry_persistent_connection = False
                            continue
                        raise
                    except ClientError:
                        raise
                    except OSError as exc:
                        # asyncio.TimeoutError without errno passes through
                        # unchanged; other OSErrors are wrapped.
                        if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
                            raise
                        raise ClientOSError(*exc.args) from exc

                    # Update cookies from raw headers to preserve duplicates
                    if resp._raw_cookie_headers:
                        self._cookie_jar.update_cookies_from_headers(
                            resp._raw_cookie_headers, resp.url
                        )

                    # redirects
                    if resp.status in (301, 302, 303, 307, 308) and allow_redirects:
                        for trace in traces:
                            await trace.send_request_redirect(
                                method, url.update_query(params), headers, resp
                            )

                        redirects += 1
                        history.append(resp)
                        # max_redirects == 0 (falsy) disables this limit.
                        if max_redirects and redirects >= max_redirects:
                            if req._body is not None:
                                await req._body.close()
                            resp.close()
                            raise TooManyRedirects(
                                history[0].request_info, tuple(history)
                            )

                        # For 301 and 302, mimic IE, now changed in RFC
                        # https://github.com/kennethreitz/requests/pull/269
                        if (resp.status == 303 and resp.method != hdrs.METH_HEAD) or (
                            resp.status in (301, 302) and resp.method == hdrs.METH_POST
                        ):
                            method = hdrs.METH_GET
                            data = None
                            if headers.get(hdrs.CONTENT_LENGTH):
                                headers.pop(hdrs.CONTENT_LENGTH)
                        else:
                            # For 307/308, always preserve the request body
                            # For 301/302 with non-POST methods, preserve the request body
                            # https://www.rfc-editor.org/rfc/rfc9110#section-15.4.3-3.1
                            # Use the existing payload to avoid recreating it from
                            # a potentially consumed file.
                            #
                            # If the payload is already consumed and cannot be replayed,
                            # fail fast instead of silently sending an empty body.
                            if req._body.consumed:
                                resp.close()
                                raise ClientPayloadError(
                                    "Cannot follow redirect with a consumed request "
                                    "body. Use bytes, a seekable file-like object, "
                                    "or set allow_redirects=False."
                                )
                            data = req._body

                        r_url = resp.headers.get(hdrs.LOCATION) or resp.headers.get(
                            hdrs.URI
                        )
                        if r_url is None:
                            # see github.com/aio-libs/aiohttp/issues/2022
                            break
                        else:
                            # reading from correct redirection
                            # response is forbidden
                            resp.release()

                        try:
                            parsed_redirect_url = URL(
                                r_url, encoded=not self._requote_redirect_url
                            )
                        except ValueError as e:
                            if req._body is not None:
                                await req._body.close()
                            resp.close()
                            raise InvalidUrlRedirectClientError(
                                r_url,
                                "Server attempted redirecting to a location that does not look like a URL",
                            ) from e

                        scheme = parsed_redirect_url.scheme
                        if scheme not in HTTP_AND_EMPTY_SCHEMA_SET:
                            if req._body is not None:
                                await req._body.close()
                            resp.close()
                            raise NonHttpUrlRedirectClientError(r_url)
                        elif not scheme:
                            # Scheme-less Location: resolve against the
                            # current URL.
                            parsed_redirect_url = url.join(parsed_redirect_url)

                        is_same_host_https_redirect = (
                            url.host == parsed_redirect_url.host
                            and parsed_redirect_url.scheme == "https"
                            and url.scheme == "http"
                        )

                        try:
                            redirect_origin = parsed_redirect_url.origin()
                        except ValueError as origin_val_err:
                            if req._body is not None:
                                await req._body.close()
                            resp.close()
                            raise InvalidUrlRedirectClientError(
                                parsed_redirect_url,
                                "Invalid redirect URL origin",
                            ) from origin_val_err

                        # Leaving the origin (other than a same-host
                        # http -> https upgrade) drops credentials and cookies.
                        if (
                            not is_same_host_https_redirect
                            and url.origin() != redirect_origin
                        ):
                            auth = None
                            headers.pop(hdrs.AUTHORIZATION, None)
                            headers.pop(hdrs.COOKIE, None)
                            headers.pop(hdrs.PROXY_AUTHORIZATION, None)

                        url = parsed_redirect_url
                        params = {}
                        resp.release()
                        continue

                    break

            if req._body is not None:
                await req._body.close()
            # check response status
            if raise_for_status is None:
                raise_for_status = self._raise_for_status

            if raise_for_status is None:
                pass
            elif callable(raise_for_status):
                await raise_for_status(resp)
            elif raise_for_status:
                resp.raise_for_status()

            # register connection
            if handle is not None:
                if resp.connection is not None:
                    resp.connection.add_callback(handle.cancel)
                else:
                    handle.cancel()

            resp._history = tuple(history)

            for trace in traces:
                await trace.send_request_end(
                    method, url.update_query(params), headers, resp
                )
            return resp

        except BaseException as e:
            # cleanup timer
            tm.close()
            if handle:
                handle.cancel()
                handle = None

            for trace in traces:
                await trace.send_request_exception(
                    method, url.update_query(params), headers, e
                )
            raise

    # Typing-only overloads: the decode_text literal selects the type
    # parameter of the returned ClientWebSocketResponse.
    if sys.version_info >= (3, 11) and TYPE_CHECKING:

        @overload
        def ws_connect(
            self,
            url: StrOrURL,
            *,
            decode_text: Literal[True] = ...,
            **kwargs: Unpack[_WSConnectOptions],
        ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ...

        @overload
        def ws_connect(
            self,
            url: StrOrURL,
            *,
            decode_text: Literal[False],
            **kwargs: Unpack[_WSConnectOptions],
        ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ...

        @overload
        def ws_connect(
            self,
            url: StrOrURL,
            *,
            decode_text: bool = ...,
            **kwargs: Unpack[_WSConnectOptions],
        ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ...
# NOTE(review): this is a method of ClientSession; the class header lies
# outside this span.
#
# Runtime implementation of ws_connect(): it does no work itself, it forwards
# every argument unchanged to the _ws_connect() coroutine and wraps that
# coroutine in _WSRequestContextManager.
def ws_connect(
    self,
    url: StrOrURL,
    *,
    method: str = hdrs.METH_GET,
    protocols: Collection[str] = (),
    timeout: ClientWSTimeout | _SENTINEL = sentinel,
    receive_timeout: float | None = None,
    autoclose: bool = True,
    autoping: bool = True,
    heartbeat: float | None = None,
    auth: BasicAuth | None = None,
    origin: str | None = None,
    params: Query = None,
    headers: LooseHeaders | None = None,
    proxy: StrOrURL | None = None,
    proxy_auth: BasicAuth | None = None,
    ssl: SSLContext | bool | Fingerprint = True,
    server_hostname: str | None = None,
    proxy_headers: LooseHeaders | None = None,
    compress: int = 0,
    max_msg_size: int = 4 * 1024 * 1024,
    decode_text: bool = True,
) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]":
    """Initiate websocket connection."""
    return _WSRequestContextManager(
        self._ws_connect(
            url,
            method=method,
            protocols=protocols,
            timeout=timeout,
            receive_timeout=receive_timeout,
            autoclose=autoclose,
            autoping=autoping,
            heartbeat=heartbeat,
            auth=auth,
            origin=origin,
            params=params,
            headers=headers,
            proxy=proxy,
            proxy_auth=proxy_auth,
            ssl=ssl,
            server_hostname=server_hostname,
            proxy_headers=proxy_headers,
            compress=compress,
            max_msg_size=max_msg_size,
            decode_text=decode_text,
        )
    )


# Typing-only overloads for _ws_connect(): the decode_text literal selects the
# type parameter of the returned ClientWebSocketResponse. They are never
# defined at runtime because the condition includes TYPE_CHECKING.
if sys.version_info >= (3, 11) and TYPE_CHECKING:

    @overload
    async def _ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: Literal[True] = ...,
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "ClientWebSocketResponse[Literal[True]]": ...

    @overload
    async def _ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: Literal[False],
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "ClientWebSocketResponse[Literal[False]]": ...

    @overload
    async def _ws_connect(
        self,
        url: StrOrURL,
        *,
        decode_text: bool = ...,
        **kwargs: Unpack[_WSConnectOptions],
    ) -> "ClientWebSocketResponse[bool]": ...
async def _ws_connect( self, url: StrOrURL, *, method: str = hdrs.METH_GET, protocols: Collection[str] = (), timeout: ClientWSTimeout | _SENTINEL = sentinel, receive_timeout: float | None = None, autoclose: bool = True, autoping: bool = True, heartbeat: float | None = None, auth: BasicAuth | None = None, origin: str | None = None, params: Query = None, headers: LooseHeaders | None = None, proxy: StrOrURL | None = None, proxy_auth: BasicAuth | None = None, ssl: SSLContext | bool | Fingerprint = True, server_hostname: str | None = None, proxy_headers: LooseHeaders | None = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, decode_text: bool = True, ) -> "ClientWebSocketResponse[bool]": if timeout is not sentinel: if isinstance(timeout, ClientWSTimeout): ws_timeout = timeout else: warnings.warn( # type: ignore[unreachable] "parameter 'timeout' of type 'float' " "is deprecated, please use " "'timeout=ClientWSTimeout(ws_close=...)'", DeprecationWarning, stacklevel=2, ) ws_timeout = ClientWSTimeout(ws_close=timeout) else: ws_timeout = DEFAULT_WS_CLIENT_TIMEOUT if receive_timeout is not None: warnings.warn( "float parameter 'receive_timeout' " "is deprecated, please use parameter " "'timeout=ClientWSTimeout(ws_receive=...)'", DeprecationWarning, stacklevel=2, ) ws_timeout = dataclasses.replace(ws_timeout, ws_receive=receive_timeout) if headers is None: real_headers: CIMultiDict[str] = CIMultiDict() else: real_headers = CIMultiDict(headers) default_headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "Upgrade", hdrs.SEC_WEBSOCKET_VERSION: "13", } for key, value in default_headers.items(): real_headers.setdefault(key, value) sec_key = base64.b64encode(os.urandom(16)) real_headers[hdrs.SEC_WEBSOCKET_KEY] = sec_key.decode() if protocols: real_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ",".join(protocols) if origin is not None: real_headers[hdrs.ORIGIN] = origin if compress: extstr = ws_ext_gen(compress=compress) real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr 
if not isinstance(ssl, SSL_ALLOWED_TYPES): raise TypeError( "ssl should be SSLContext, Fingerprint, or bool, " f"got {ssl!r} instead." ) # send request resp = await self.request( method, url, params=params, headers=real_headers, read_until_eof=False, auth=auth, proxy=proxy, proxy_auth=proxy_auth, ssl=ssl, server_hostname=server_hostname, proxy_headers=proxy_headers, ) try: # check handshake if resp.status != 101: raise WSServerHandshakeError( resp.request_info, resp.history, message="Invalid response status", status=resp.status, headers=resp.headers, ) if resp.headers.get(hdrs.UPGRADE, "").lower() != "websocket": raise WSServerHandshakeError( resp.request_info, resp.history, message="Invalid upgrade header", status=resp.status, headers=resp.headers, ) if resp.headers.get(hdrs.CONNECTION, "").lower() != "upgrade": raise WSServerHandshakeError( resp.request_info, resp.history, message="Invalid connection header", status=resp.status, headers=resp.headers, ) # key calculation r_key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, "") match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode() if r_key != match: raise WSServerHandshakeError( resp.request_info, resp.history, message="Invalid challenge response", status=resp.status, headers=resp.headers, ) # websocket protocol protocol = None if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers: resp_protocols = [ proto.strip() for proto in resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",") ] for proto in resp_protocols: if proto in protocols: protocol = proto break # websocket compress notakeover = False if compress: compress_hdrs = resp.headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS) if compress_hdrs: try: compress, notakeover = ws_ext_parse(compress_hdrs) except WSHandshakeError as exc: raise WSServerHandshakeError( resp.request_info, resp.history, message=exc.args[0], status=resp.status, headers=resp.headers, ) from exc else: compress = 0 notakeover = False conn = resp.connection assert conn is not 
None conn_proto = conn.protocol assert conn_proto is not None # For WS connection the read_timeout must be either ws_timeout.ws_receive or greater # None == no timeout, i.e. infinite timeout, so None is the max timeout possible if ws_timeout.ws_receive is None: # Reset regardless conn_proto.read_timeout = None elif conn_proto.read_timeout is not None: conn_proto.read_timeout = max( ws_timeout.ws_receive, conn_proto.read_timeout ) transport = conn.transport assert transport is not None reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop) writer = WebSocketWriter( conn_proto, transport, use_mask=True, compress=compress, notakeover=notakeover, ) except BaseException: resp.close() raise else: ws_resp = self._ws_response_class( reader, writer, protocol, resp, ws_timeout, autoclose, autoping, self._loop, heartbeat=heartbeat, compress=compress, client_notakeover=notakeover, ) parser = WebSocketReader(reader, max_msg_size, decode_text=decode_text) cb = None if heartbeat is None else ws_resp._on_data_received conn_proto.set_parser(parser, reader, data_received_cb=cb) return ws_resp def _prepare_headers(self, headers: LooseHeaders | None) -> "CIMultiDict[str]": """Add default headers and transform it to CIMultiDict""" # Convert headers to MultiDict result = CIMultiDict(self._default_headers) if headers: if not isinstance(headers, (MultiDictProxy, MultiDict)): headers = CIMultiDict(headers) added_names: set[str] = set() for key, value in headers.items(): if key in added_names: result.add(key, value) else: result[key] = value added_names.add(key) return result def _get_netrc_auth(self, host: str) -> BasicAuth | None: """ Get auth from netrc for the given host. This method is designed to be called in an executor to avoid blocking I/O in the event loop. 
""" netrc_obj = netrc_from_env() try: return basicauth_from_netrc(netrc_obj, host) except LookupError: return None if sys.version_info >= (3, 11) and TYPE_CHECKING: def get( self, url: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> "_RequestContextManager": ... def options( self, url: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> "_RequestContextManager": ... def head( self, url: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> "_RequestContextManager": ... def post( self, url: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> "_RequestContextManager": ... def put( self, url: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> "_RequestContextManager": ... def patch( self, url: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> "_RequestContextManager": ... def delete( self, url: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> "_RequestContextManager": ... else: def get( self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any ) -> "_RequestContextManager": """Perform HTTP GET request.""" return _RequestContextManager( self._request( hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs ) ) def options( self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any ) -> "_RequestContextManager": """Perform HTTP OPTIONS request.""" return _RequestContextManager( self._request( hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs ) ) def head( self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any ) -> "_RequestContextManager": """Perform HTTP HEAD request.""" return _RequestContextManager( self._request( hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs ) ) def post( self, url: StrOrURL, *, data: Any = None, **kwargs: Any ) -> "_RequestContextManager": """Perform HTTP POST request.""" return _RequestContextManager( self._request(hdrs.METH_POST, url, data=data, **kwargs) ) def put( self, url: StrOrURL, *, data: Any = None, **kwargs: Any ) -> "_RequestContextManager": """Perform HTTP PUT 
async def close(self) -> None:
    """Close the session and release its connector.

    Does nothing when the session already reports itself as closed.
    The connector is closed only if the session owns it; in either case
    the session drops its reference to the connector afterwards.
    """
    if self.closed:
        return
    connector = self._connector
    if connector is not None and self._connector_owner:
        await connector.close()
    self._connector = None
def detach(self) -> None:
    """Detach connector from session without closing the former.

    Session is switched to closed state anyway.
    """
    # Only the session's reference is dropped; the connector's own close()
    # is not called here. The ``closed`` property treats a missing
    # connector as closed, which is why the session reads as closed
    # from now on.
    self._connector = None
class _SessionRequestContextManager:
    """Async context manager that ties a one-off session to one request.

    Entering awaits the request coroutine and returns its response.
    Leaving closes the response and then the session. If the request
    coroutine itself fails, the session is closed before the error is
    re-raised.
    """

    __slots__ = ("_coro", "_resp", "_session")

    def __init__(
        self,
        coro: Coroutine[asyncio.Future[Any], None, ClientResponse],
        session: ClientSession,
    ) -> None:
        self._coro = coro
        self._resp: ClientResponse | None = None
        self._session = session

    async def __aenter__(self) -> ClientResponse:
        try:
            self._resp = await self._coro
        except BaseException:
            # The caller never gets a response to clean up through
            # __aexit__, so the session has to be closed here.
            await self._session.close()
            raise
        else:
            return self._resp

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        assert self._resp is not None
        try:
            self._resp.close()
        finally:
            # Close the session even when closing the response raises;
            # otherwise the session would be left open.
            await self._session.close()
method - HTTP method url - request url params - (optional) Dictionary or bytes to be sent in the query string of the new request data - (optional) Dictionary, bytes, or file-like object to send in the body of the request json - (optional) Any json compatible python object headers - (optional) Dictionary of HTTP Headers to send with the request cookies - (optional) Dict object to send with the request auth - (optional) BasicAuth named tuple represent HTTP Basic Auth auth - aiohttp.helpers.BasicAuth allow_redirects - (optional) If set to False, do not follow redirects version - Request HTTP version. compress - Set to True if request has to be compressed with deflate encoding. chunked - Set to chunk size for chunked transfer encoding. expect100 - Expect 100-continue response from server. connector - BaseConnector sub-class instance to support connection pooling. read_until_eof - Read response until eof if response does not have Content-Length header. loop - Optional event loop. timeout - Optional ClientTimeout settings structure, 5min total timeout by default. Usage:: >>> import aiohttp >>> async with aiohttp.request('GET', 'http://python.org/') as resp: ... print(resp) ... 
class ClientError(Exception):
    """Base class for client connection errors."""


class ClientResponseError(ClientError):
    """Base class for exceptions that occur after getting a response.

    request_info: An instance of RequestInfo.
    history: A sequence of responses, if redirects occurred.
    status: HTTP status code; stored as 0 when not given.
    message: Error message.
    headers: Response headers.
    """

    def __init__(
        self,
        request_info: RequestInfo,
        history: tuple[ClientResponse, ...],
        *,
        status: int | None = None,
        message: str = "",
        headers: MultiMapping[str] | None = None,
    ) -> None:
        self.request_info = request_info
        # A missing status is normalised to 0 so the attribute is always an int.
        self.status = status if status is not None else 0
        self.message = message
        self.headers = headers
        self.history = history
        self.args = (request_info, history)

    def __str__(self) -> str:
        url = str(self.request_info.real_url)
        return f"{self.status}, message={self.message!r}, url={url!r}"

    def __repr__(self) -> str:
        # The positional arguments are always shown; the keyword ones are
        # shown only when they differ from their defaults.
        parts = [repr(self.request_info), repr(self.history)]
        if self.status != 0:
            parts.append(f"status={self.status!r}")
        if self.message != "":
            parts.append(f"message={self.message!r}")
        if self.headers is not None:
            parts.append(f"headers={self.headers!r}")
        return f"{type(self).__name__}({', '.join(parts)})"
class UnixClientConnectorError(ClientConnectorError):
    """Unix connector error.

    Raised in :py:class:`aiohttp.connector.UnixConnector`
    if connection to unix socket can not be established.
    """

    def __init__(
        self, path: str, connection_key: ConnectionKey, os_error: OSError
    ) -> None:
        # Remember the socket path first; the parent initialiser takes care
        # of the connection key and the underlying OS error.
        self._path = path
        super().__init__(connection_key, os_error)

    @property
    def path(self) -> str:
        """Path that was passed to the constructor."""
        return self._path

    def __str__(self) -> str:
        ssl_setting = self.ssl
        shown_ssl = "default" if ssl_setting is True else ssl_setting
        return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format(
            self, shown_ssl, self.strerror
        )
class ClientConnectorCertificateError(*cert_errors_bases):  # type: ignore[misc]
    """Response certificate error."""

    _conn_key: ConnectionKey

    def __init__(
        # TODO: If we require ssl in future, this can become ssl.CertificateError
        self,
        connection_key: ConnectionKey,
        certificate_error: Exception,
    ) -> None:
        # ssl.CertificateError has errno and strerror, so it (like any
        # OSError) can be passed to the parent initialiser unchanged; any
        # other exception is replaced by a blank OSError for that call.
        os_like_types = cert_errors + (OSError,)
        os_error = (
            certificate_error
            if isinstance(certificate_error, os_like_types)
            else OSError()
        )
        super().__init__(connection_key, os_error)
        self._certificate_error = certificate_error
        self.args = (connection_key, certificate_error)

    @property
    def certificate_error(self) -> Exception:
        """The exception that was passed to the constructor."""
        return self._certificate_error

    @property
    def host(self) -> str:
        return self._conn_key.host

    @property
    def port(self) -> int | None:
        return self._conn_key.port

    @property
    def ssl(self) -> bool:
        return self._conn_key.is_ssl

    def __str__(self) -> str:
        cause = self.certificate_error
        return (
            f"Cannot connect to host {self.host}:{self.port} ssl:{self.ssl} "
            f"[{cause.__class__.__name__}: "
            f"{cause.args}]"
        )
class DigestAuthChallenge(TypedDict, total=False):
    """Challenge parameters kept from a Digest ``WWW-Authenticate`` header."""

    realm: str
    nonce: str
    qop: str
    algorithm: str
    opaque: str
    domain: str
    stale: str


# Maps an upper-case algorithm token to the hashlib constructor used for it.
# The "-SESS" variants share the constructor of their base algorithm.
DigestFunctions: dict[str, Callable[[bytes], "hashlib._Hash"]] = {
    "MD5": hashlib.md5,
    "MD5-SESS": hashlib.md5,
    "SHA": hashlib.sha1,
    "SHA-SESS": hashlib.sha1,
    "SHA256": hashlib.sha256,
    "SHA256-SESS": hashlib.sha256,
    "SHA-256": hashlib.sha256,
    "SHA-256-SESS": hashlib.sha256,
    "SHA512": hashlib.sha512,
    "SHA512-SESS": hashlib.sha512,
    "SHA-512": hashlib.sha512,
    "SHA-512-SESS": hashlib.sha512,
}


# Compile the regex pattern once at module level for performance
_HEADER_PAIRS_PATTERN = re.compile(
    r'(?:^|\s|,\s*)(\w+)\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))'
    if sys.version_info < (3, 11)
    else r'(?:^|\s|,\s*)((?>\w+))\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))'
    # Pieces of the pattern, left to right:
    #   (?:^|\s|,\s*)      valid start or separator before a key
    #   (\w+)              alphanumeric key (on 3.11+ wrapped in an atomic
    #                      group, which reduces backtracking)
    #   \s*=\s*            the "=" delimiter with optional whitespace
    #   "((?:[^"\\]|\\.)*)"  quoted value: anything but " or \, or an
    #                      escaped character; may be empty
    #   ([^\s,]+)          unquoted value: at least one character that is
    #                      neither whitespace nor a comma
)


# RFC 7616: Challenge parameters to extract
CHALLENGE_FIELDS: Final[
    tuple[
        Literal["realm", "nonce", "qop", "algorithm", "opaque", "domain", "stale"], ...
    ]
] = (
    "realm",
    "nonce",
    "qop",
    "algorithm",
    "opaque",
    "domain",
    "stale",
)

# Supported digest authentication algorithms
# Use a tuple of sorted keys for predictable documentation and error messages
SUPPORTED_ALGORITHMS: Final[tuple[str, ...]] = tuple(sorted(DigestFunctions))

# RFC 7616: Fields that require quoting in the Digest auth header
# These fields must be enclosed in double quotes in the Authorization header.
# Algorithm, qop, and nc are never quoted per RFC specifications.
# This frozen set is used by the template-based header construction to
# automatically determine which fields need quotes.
QUOTED_AUTH_FIELDS: Final[frozenset[str]] = frozenset(
    {"username", "realm", "nonce", "uri", "response", "opaque", "cnonce"}
)


def escape_quotes(value: str) -> str:
    """Escape double quotes for HTTP header values.

    Every ``"`` is prefixed with a backslash; no other character is changed.
    """
    return value.replace('"', '\\"')


def unescape_quotes(value: str) -> str:
    """Unescape double quotes in HTTP header values.

    Only the two-character sequence backslash + ``"`` is rewritten; other
    backslash sequences are returned as they are.
    """
    return value.replace('\\"', '"')


def parse_header_pairs(header: str) -> dict[str, str]:
    """
    Parse key-value pairs from WWW-Authenticate or similar HTTP headers.

    Both quoted and unquoted values are accepted, commas inside quoted
    values are kept, and whitespace around keys, "=" and separators is
    tolerated. When a key occurs more than once the last value wins.

    Examples of supported formats:
      - key1="value1", key2=value2
      - key1 = "value1" , key2="value, with, commas"
      - key1=value1,key2="value2"
      - realm="example.com", nonce="12345", qop="auth"

    Args:
        header: The header value string to parse

    Returns:
        Dictionary mapping parameter names to their values
    """
    pairs: dict[str, str] = {}
    for raw_key, quoted_val, unquoted_val in _HEADER_PAIRS_PATTERN.findall(header):
        name = raw_key.strip()
        if not name:
            continue
        # findall yields "" for the alternative that did not take part in
        # the match, so an empty quoted group falls back to the unquoted one.
        pairs[name] = unescape_quotes(quoted_val) if quoted_val else unquoted_val
    return pairs
def __init__(
    self,
    login: str,
    password: str,
    preemptive: bool = True,
) -> None:
    """Validate the credentials and set up empty authentication state.

    Args:
        login: User name; must not be None and must not contain ":".
        password: Password; must not be None.
        preemptive: Stored as ``self._preemptive``.

    Raises:
        ValueError: If ``login`` or ``password`` is None, or ``login``
            contains a colon.
    """
    # Validate everything up front so no state is set for bad input.
    if login is None:
        raise ValueError("None is not allowed as login value")
    if password is None:
        raise ValueError("None is not allowed as password value")
    if ":" in login:
        raise ValueError('A ":" is not allowed in username (RFC 1945#section-11.1)')

    login_bytes = login.encode("utf-8")
    password_bytes = password.encode("utf-8")

    # Credentials: the text form of the login plus UTF-8 encoded forms.
    self._login_str: Final[str] = login
    self._login_bytes: Final[bytes] = login_bytes
    self._password_bytes: Final[bytes] = password_bytes

    # Nonce bookkeeping starts empty.
    self._last_nonce_bytes = b""
    self._nonce_count = 0

    # No challenge has been received yet.
    self._challenge: DigestAuthChallenge = {}
    self._preemptive: bool = preemptive
    # Set of URLs defining the protection space
    self._protection_space: list[str] = []
Args: method: The HTTP method (GET, POST, etc.) url: The request URL body: The request body (used for qop=auth-int) Returns: A fully formatted Digest authorization header string Raises: ClientError: If the challenge is missing required parameters or contains unsupported values """ challenge = self._challenge if "realm" not in challenge: raise ClientError( "Malformed Digest auth challenge: Missing 'realm' parameter" ) if "nonce" not in challenge: raise ClientError( "Malformed Digest auth challenge: Missing 'nonce' parameter" ) # Empty realm values are allowed per RFC 7616 (SHOULD, not MUST, contain host name) realm = challenge["realm"] nonce = challenge["nonce"] # Empty nonce values are not allowed as they are security-critical for replay protection if not nonce: raise ClientError( "Security issue: Digest auth challenge contains empty 'nonce' value" ) qop_raw = challenge.get("qop", "") # Preserve original algorithm case for response while using uppercase for processing algorithm_original = challenge.get("algorithm", "MD5") algorithm = algorithm_original.upper() opaque = challenge.get("opaque", "") # Convert string values to bytes once nonce_bytes = nonce.encode("utf-8") realm_bytes = realm.encode("utf-8") path = URL(url).path_qs # Process QoP qop = "" qop_bytes = b"" if qop_raw: valid_qops = {"auth", "auth-int"}.intersection( {q.strip() for q in qop_raw.split(",") if q.strip()} ) if not valid_qops: raise ClientError( f"Digest auth error: Unsupported Quality of Protection (qop) value(s): {qop_raw}" ) qop = "auth-int" if "auth-int" in valid_qops else "auth" qop_bytes = qop.encode("utf-8") if algorithm not in DigestFunctions: raise ClientError( f"Digest auth error: Unsupported hash algorithm: {algorithm}. 
" f"Supported algorithms: {', '.join(SUPPORTED_ALGORITHMS)}" ) hash_fn: Final = DigestFunctions[algorithm] def H(x: bytes) -> bytes: """RFC 7616 Section 3: Hash function H(data) = hex(hash(data)).""" return hash_fn(x).hexdigest().encode() def KD(s: bytes, d: bytes) -> bytes: """RFC 7616 Section 3: KD(secret, data) = H(concat(secret, ":", data)).""" return H(b":".join((s, d))) # Calculate A1 and A2 A1 = b":".join((self._login_bytes, realm_bytes, self._password_bytes)) A2 = f"{method.upper()}:{path}".encode() if qop == "auth-int": if isinstance(body, Payload): # will always be empty bytes unless Payload entity_bytes = await body.as_bytes() # Get bytes from Payload else: entity_bytes = body entity_hash = H(entity_bytes) A2 = b":".join((A2, entity_hash)) HA1 = H(A1) HA2 = H(A2) # Nonce count handling if nonce_bytes == self._last_nonce_bytes: self._nonce_count += 1 else: self._nonce_count = 1 self._last_nonce_bytes = nonce_bytes ncvalue = f"{self._nonce_count:08x}" ncvalue_bytes = ncvalue.encode("utf-8") # Generate client nonce cnonce = hashlib.sha1( b"".join( [ str(self._nonce_count).encode("utf-8"), nonce_bytes, time.ctime().encode("utf-8"), os.urandom(8), ] ) ).hexdigest()[:16] cnonce_bytes = cnonce.encode("utf-8") # Special handling for session-based algorithms if algorithm.upper().endswith("-SESS"): HA1 = H(b":".join((HA1, nonce_bytes, cnonce_bytes))) # Calculate the response digest if qop: noncebit = b":".join( (nonce_bytes, ncvalue_bytes, cnonce_bytes, qop_bytes, HA2) ) response_digest = KD(HA1, noncebit) else: response_digest = KD(HA1, b":".join((nonce_bytes, HA2))) # Define a dict mapping of header fields to their values # Group fields into always-present, optional, and qop-dependent header_fields = { # Always present fields "username": escape_quotes(self._login_str), "realm": escape_quotes(realm), "nonce": escape_quotes(nonce), "uri": path, "response": response_digest.decode(), "algorithm": algorithm_original, } # Optional fields if opaque: 
def _in_protection_space(self, url: URL) -> bool:
    """
    Check if the given URL is within the current protection space.

    According to RFC 7616, a URI is in the protection space if any URI
    in the protection space is a prefix of it (after both have been made
    absolute). A bare string prefix is not accepted on its own: the match
    has to end on a URL component boundary, so that an entry ending in
    ``/app`` does not also cover ``/apple``.

    Args:
        url: The request URL; only its string form is examined.

    Returns:
        True if an entry of ``self._protection_space`` covers the URL,
        False otherwise (including when the protection space is empty).
    """
    request_str = str(url)
    for space_str in self._protection_space:
        # An empty entry describes no scope at all; skipping it also keeps
        # the checks below from indexing into an empty string.
        if not space_str or not request_str.startswith(space_str):
            continue
        # Exact match or space ends with / (proper directory prefix)
        if len(request_str) == len(space_str) or space_str.endswith("/"):
            return True
        # The character right after the prefix must open a new component:
        # "/" for a deeper path segment, or "?" / "#" for the query or
        # fragment of exactly the same path.
        if request_str[len(space_str)] in "/?#":
            return True
    return False
return False if not headers: # We have a digest scheme but no parameters return False # Malformed digest header, missing parameters # We have a digest auth header with content if not (header_pairs := parse_header_pairs(headers)): # Failed to parse any key-value pairs return False # Malformed digest header, no valid parameters # Extract challenge parameters self._challenge = {} for field in CHALLENGE_FIELDS: if (value := header_pairs.get(field)) is not None: self._challenge[field] = value # Update protection space based on domain parameter or default to origin origin = response.url.origin() if domain := self._challenge.get("domain"): # Parse space-separated list of URIs self._protection_space = [] for uri in domain.split(): # Remove quotes if present uri = uri.strip('"') if uri.startswith("/"): # Path-absolute, relative to origin self._protection_space.append(str(origin.join(URL(uri)))) else: # Absolute URI self._protection_space.append(str(URL(uri))) else: # No domain specified, protection space is entire origin self._protection_space = [str(origin)] # Return True only if we found at least one challenge parameter return bool(self._challenge) async def __call__( self, request: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: """Run the digest auth middleware.""" response = None for retry_count in range(2): # Apply authorization header if: # 1. This is a retry after 401 (retry_count > 0), OR # 2. 
def build_client_middlewares(
    handler: ClientHandlerType,
    middlewares: Sequence[ClientMiddlewareType],
) -> ClientHandlerType:
    """
    Wrap ``handler`` with the given client middlewares.

    Middlewares are bound from the last to the first, so the first entry
    of ``middlewares`` becomes the outermost layer and wraps every later
    middleware as well as the handler. With an empty sequence the handler
    itself is returned.

    Plain closures are used instead of partial/update_wrapper to minimize
    overhead, and nothing is cached so that no reference to stateful
    middleware is held.
    """

    def bind(mw: ClientMiddlewareType, next_h: ClientHandlerType) -> ClientHandlerType:
        # Each layer gets its own closure over its (middleware, next) pair.
        async def wrapped(req: ClientRequest) -> ClientResponse:
            return await mw(req, next_h)

        return wrapped

    # Optimize for single middleware case
    if len(middlewares) == 1:
        return bind(middlewares[0], handler)

    # Build the chain for multiple middlewares, innermost layer first.
    chain = handler
    for mw in reversed(middlewares):
        chain = bind(mw, chain)
    return chain
class ResponseHandler(BaseProtocol, DataQueue[tuple[RawResponseMessage, StreamReader]]):
    """Helper class to adapt between Protocol and StreamReader.

    Bytes arriving in data_received() are fed either to the HTTP response
    parser created by set_response_params() or to the custom payload parser
    installed with set_parser(). Parsed ``(message, payload)`` pairs are
    queued with feed_data(). Bytes that arrive while no parser can take them
    are kept in ``_tail`` and replayed once a parser is installed.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        """Initialise both base classes and reset all per-connection state."""
        BaseProtocol.__init__(self, loop=loop)
        DataQueue.__init__(self, loop)

        self._should_close = False

        # Payload of the message currently being received, if any.
        self._payload: _Payload | None = None
        self._skip_payload = False
        # Custom (WebSocket) parser and its optional per-chunk callback.
        self._payload_parser: WebSocketReader | None = None
        self._data_received_cb: Callable[[], None] | None = None

        self._timer = None

        # Bytes received while no parser could consume them.
        self._tail = b""
        self._upgraded = False
        self._parser: HttpResponseParser | None = None

        self._read_timeout: float | None = None
        self._read_timeout_handle: asyncio.TimerHandle | None = None

        self._timeout_ceil_threshold: float | None = 5

        # Lazily created by the ``closed`` property.
        self._closed: None | asyncio.Future[None] = None
        self._connection_lost_called = False

    @property
    def closed(self) -> None | asyncio.Future[None]:
        """Future that is set when the connection is closed.

        This property returns a Future that will be completed when the connection
        is closed. The Future is created lazily on first access to avoid creating
        futures that will never be awaited.

        Returns:
            - A Future[None] if the connection is still open or was closed after
              this property was accessed
            - None if connection_lost() was already called before this property
              was ever accessed (indicating no one is waiting for the closure)
        """
        if self._closed is None and not self._connection_lost_called:
            self._closed = self._loop.create_future()
        return self._closed

    @property
    def upgraded(self) -> bool:
        """Upgrade flag reported by the most recent HTTP parser feed."""
        return self._upgraded

    @property
    def should_close(self) -> bool:
        """Return True if any close condition currently holds.

        The conditions are: the explicit flag, a payload that has not reached
        EOF, an upgraded connection, a stored exception, an active custom
        payload parser, a non-empty ``_buffer`` or leftover ``_tail`` bytes.
        ``_exception`` and ``_buffer`` are not defined in this class
        (presumably inherited from DataQueue -- confirm).
        """
        return bool(
            self._should_close
            or (self._payload is not None and not self._payload.is_eof())
            or self._upgraded
            or self._exception is not None
            or self._payload_parser is not None
            or self._buffer
            or self._tail
        )

    def force_close(self) -> None:
        """Set the explicit close flag so that should_close becomes True."""
        self._should_close = True

    def close(self) -> None:
        """Close the transport, if any, and drop transport/payload references."""
        self._exception = None  # Break cyclic references
        transport = self.transport
        if transport is not None:
            transport.close()
            self.transport = None
            self._payload = None
            self._drop_timeout()

    def abort(self) -> None:
        """Same as close() but calls ``transport.abort()`` on the transport."""
        self._exception = None  # Break cyclic references
        transport = self.transport
        if transport is not None:
            transport.abort()
            self.transport = None
            self._payload = None
            self._drop_timeout()

    def is_connected(self) -> bool:
        """Return True while a transport is attached and not closing."""
        return self.transport is not None and not self.transport.is_closing()

    def connection_lost(self, exc: BaseException | None) -> None:
        """Handle loss of the connection.

        ``exc`` is the error that caused the loss, or None when the connection
        closed cleanly. The method, in order:

        - resolves the ``closed`` future, if one was created;
        - sends EOF to the custom payload parser and the HTTP parser;
        - if the queue is not at EOF, stores an exception on it
          (ClientOSError for an OSError, ServerDisconnectedError for a clean
          close, otherwise ``exc`` itself);
        - clears parser/payload state and delegates to the base class.
        """
        self._connection_lost_called = True
        self._drop_timeout()

        original_connection_error = exc
        reraised_exc = original_connection_error

        connection_closed_cleanly = original_connection_error is None

        if self._closed is not None:
            # If someone is waiting for the closed future,
            # we should set it to None or an exception. If
            # self._closed is None, it means that
            # connection_lost() was called already
            # or nobody is waiting for it.
            if connection_closed_cleanly:
                set_result(self._closed, None)
            else:
                assert original_connection_error is not None
                set_exception(
                    self._closed,
                    ClientConnectionError(
                        f"Connection lost: {original_connection_error !s}",
                    ),
                    original_connection_error,
                )

        if self._payload_parser is not None:
            with suppress(Exception):  # FIXME: log this somehow?
                self._payload_parser.feed_eof()

        uncompleted = None
        if self._parser is not None:
            try:
                uncompleted = self._parser.feed_eof()
            except Exception as underlying_exc:
                # The parser rejected EOF: report an incomplete payload on the
                # payload object, appending the connection error if there is one.
                if self._payload is not None:
                    client_payload_exc_msg = (
                        f"Response payload is not completed: {underlying_exc !r}"
                    )
                    if not connection_closed_cleanly:
                        client_payload_exc_msg = (
                            f"{client_payload_exc_msg !s}. "
                            f"{original_connection_error !r}"
                        )

                    set_exception(
                        self._payload,
                        ClientPayloadError(client_payload_exc_msg),
                        underlying_exc,
                    )

        if not self.is_eof():
            if isinstance(original_connection_error, OSError):
                reraised_exc = ClientOSError(*original_connection_error.args)
            if connection_closed_cleanly:
                reraised_exc = ServerDisconnectedError(uncompleted)
            # assigns self._should_close to True as side effect,
            # we do it anyway below
            underlying_non_eof_exc = (
                _EXC_SENTINEL
                if connection_closed_cleanly
                else original_connection_error
            )
            assert underlying_non_eof_exc is not None
            assert reraised_exc is not None
            self.set_exception(reraised_exc, underlying_non_eof_exc)

        self._should_close = True
        self._parser = None
        self._payload = None
        self._payload_parser = None
        self._reading_paused = False

        super().connection_lost(reraised_exc)

    def eof_received(self) -> None:
        """Cancel the read timeout when the peer signals EOF."""
        # should call parser.feed_eof() most likely
        self._drop_timeout()

    def pause_reading(self) -> None:
        """Pause reading via the base class and cancel the read timeout."""
        super().pause_reading()
        self._drop_timeout()

    def resume_reading(self) -> None:
        """Resume reading via the base class and re-arm the read timeout."""
        super().resume_reading()
        self._reschedule_timeout()

    def set_exception(
        self,
        exc: type[BaseException] | BaseException,
        exc_cause: BaseException = _EXC_SENTINEL,
    ) -> None:
        """Set the close flag, cancel the timeout, then delegate to the base class."""
        self._should_close = True
        self._drop_timeout()
        super().set_exception(exc, exc_cause)

    def set_parser(
        self,
        parser: WebSocketReader,
        payload: WebSocketDataQueue,
        data_received_cb: Callable[[], None] | None = None,
    ) -> None:
        """Install a custom payload parser and replay any buffered bytes.

        ``data_received_cb``, when given, is invoked by data_received() before
        each chunk is passed to ``parser``.
        """
        self._payload = payload
        self._payload_parser = parser
        self._data_received_cb = data_received_cb

        self._drop_timeout()

        if self._tail:
            # Clear the tail before re-feeding so it is not processed twice.
            data, self._tail = self._tail, b""
            self.data_received(data)

    def set_response_params(
        self,
        *,
        timer: BaseTimerContext | None = None,
        skip_payload: bool = False,
        read_until_eof: bool = False,
        auto_decompress: bool = True,
        read_timeout: float | None = None,
        read_bufsize: int = 2**16,
        timeout_ceil_threshold: float = 5,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
        max_headers: int = 128,
    ) -> None:
        """Store response options and create a new HttpResponseParser.

        ``skip_payload``, ``read_timeout`` and ``timeout_ceil_threshold`` are
        kept on the instance; the remaining options are passed to the parser.
        Bytes buffered in ``_tail`` are replayed through data_received().
        """
        self._skip_payload = skip_payload

        self._read_timeout = read_timeout

        self._timeout_ceil_threshold = timeout_ceil_threshold

        self._parser = HttpResponseParser(
            self,
            self._loop,
            read_bufsize,
            timer=timer,
            payload_exception=ClientPayloadError,
            response_with_body=not skip_payload,
            read_until_eof=read_until_eof,
            auto_decompress=auto_decompress,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            max_headers=max_headers,
        )

        if self._tail:
            data, self._tail = self._tail, b""
            self.data_received(data)

    def _drop_timeout(self) -> None:
        """Cancel and forget the pending read-timeout handle, if any."""
        if self._read_timeout_handle is not None:
            self._read_timeout_handle.cancel()
            self._read_timeout_handle = None

    def _reschedule_timeout(self) -> None:
        """Cancel the pending handle and schedule a new one if a timeout is set."""
        timeout = self._read_timeout
        if self._read_timeout_handle is not None:
            self._read_timeout_handle.cancel()

        # A falsy timeout (None or 0) leaves no timer armed.
        if timeout:
            self._read_timeout_handle = self._loop.call_later(
                timeout, self._on_read_timeout
            )
        else:
            self._read_timeout_handle = None

    def start_timeout(self) -> None:
        """Arm (or re-arm) the read timeout."""
        self._reschedule_timeout()

    @property
    def read_timeout(self) -> float | None:
        """Read timeout passed to ``loop.call_later()``; falsy disables it."""
        return self._read_timeout

    @read_timeout.setter
    def read_timeout(self, read_timeout: float | None) -> None:
        # Only stored here; it takes effect at the next _reschedule_timeout().
        self._read_timeout = read_timeout

    def _on_read_timeout(self) -> None:
        """Timer callback: store SocketTimeoutError on the queue and the payload."""
        exc = SocketTimeoutError("Timeout on reading data from socket")
        self.set_exception(exc)
        if self._payload is not None:
            set_exception(self._payload, exc)

    def data_received(self, data: bytes) -> None:
        """Route received bytes to the active parser.

        The read timeout is re-armed on every call, including empty chunks.
        A custom payload parser, when installed, takes precedence. If the
        connection is upgraded or no HTTP parser exists, the bytes are
        buffered in ``_tail``. Otherwise they go to the HTTP parser and each
        parsed message is queued with feed_data().
        """
        self._reschedule_timeout()

        if not data:
            return

        # custom payload parser - currently always WebSocketReader
        if self._payload_parser is not None:
            if self._data_received_cb is not None:
                self._data_received_cb()
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self._payload = None
                self._payload_parser = None

                # Re-dispatch leftover bytes now that the custom parser is gone.
                if tail:
                    self.data_received(tail)
            return

        if self._upgraded or self._parser is None:
            # i.e. websocket connection, websocket parser is not set yet
            self._tail += data
            return

        # parse http messages
        try:
            messages, upgraded, tail = self._parser.feed_data(data)
        except Exception as underlying_exc:
            if self.transport is not None:
                # connection.release() could be called BEFORE
                # data_received(), the transport is already
                # closed in this case
                self.transport.close()
            # should_close is True after the call
            if isinstance(underlying_exc, HttpProcessingError):
                exc = HttpProcessingError(
                    code=underlying_exc.code,
                    message=underlying_exc.message,
                    headers=underlying_exc.headers,
                )
            else:
                exc = HttpProcessingError()
            self.set_exception(exc, underlying_exc)
            return

        self._upgraded = upgraded

        payload: StreamReader | None = None
        for message, payload in messages:
            if message.should_close:
                self._should_close = True

            self._payload = payload

            # Skipped payloads and empty-body status codes are queued with
            # EMPTY_PAYLOAD instead of the parsed payload.
            if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
                self.feed_data((message, EMPTY_PAYLOAD))
            else:
                self.feed_data((message, payload))

        if payload is not None:
            # new message(s) was processed
            # register timeout handler unsubscribing
            # either on end-of-stream or immediately for
            # EMPTY_PAYLOAD
            if payload is not EMPTY_PAYLOAD:
                payload.on_eof(self._drop_timeout)
            else:
                self._drop_timeout()

        if upgraded and tail:
            self.data_received(tail)
def _gen_default_accept_encoding() -> str:
    """Return the default ``Accept-Encoding`` header value.

    ``gzip`` and ``deflate`` are always listed; ``br`` and ``zstd`` are
    appended, in that order, only when ``HAS_BROTLI`` / ``HAS_ZSTD`` are
    truthy.
    """
    optional_codecs = (("br", HAS_BROTLI), ("zstd", HAS_ZSTD))
    available = [name for name, enabled in optional_codecs if enabled]
    return ", ".join(["gzip", "deflate", *available])
class Fingerprint:
    """Expected digest of a peer certificate, checked against a transport.

    The hash function is selected from the digest length. Lengths that map
    to md5 or sha1 are recognised only to be rejected, so in practice a
    32-byte sha256 digest is required.
    """

    HASHFUNC_BY_DIGESTLEN = {
        16: md5,
        20: sha1,
        32: sha256,
    }

    def __init__(self, fingerprint: bytes) -> None:
        """Validate ``fingerprint`` by length and remember its hash function.

        Raises ValueError for an unknown digest length and for lengths that
        correspond to md5 or sha1.
        """
        hasher = self.HASHFUNC_BY_DIGESTLEN.get(len(fingerprint))
        if not hasher:
            raise ValueError("fingerprint has invalid length")
        if any(hasher is weak for weak in (md5, sha1)):
            raise ValueError("md5 and sha1 are insecure and not supported. Use sha256.")
        self._fingerprint = fingerprint
        self._hashfunc = hasher

    @property
    def fingerprint(self) -> bytes:
        """The expected digest exactly as passed to the constructor."""
        return self._fingerprint

    def check(self, transport: asyncio.Transport) -> None:
        """Compare the peer certificate digest with the expected one.

        Does nothing when the transport reports no ``sslcontext``. Raises
        ServerFingerprintMismatch when the digests differ.
        """
        if not transport.get_extra_info("sslcontext"):
            return
        ssl_object = transport.get_extra_info("ssl_object")
        # binary_form=True: hash the raw certificate bytes.
        peer_cert = ssl_object.getpeercert(binary_form=True)
        actual = self._hashfunc(peer_cert).digest()
        if actual == self._fingerprint:
            return
        host, port, *_ = transport.get_extra_info("peername")
        raise ServerFingerprintMismatch(self._fingerprint, actual, host, port)
class ClientResponse(HeadersMixin):
    """Client-side HTTP response.

    The object is created before the response is received; start() reads the
    status line and headers from the connection's protocol and fills in
    ``version``, ``status``, ``reason``, the headers and ``content``. The body
    is read on demand by read() / text() / json().
    """

    # Some of these attributes are None when created,
    # but will be set by the start() method.
    # As the end user will likely never see the None values, we cheat the types below.
    # from the Status-Line of the response
    version: HttpVersion | None = None  # HTTP-Version
    status: int = None  # type: ignore[assignment]  # Status-Code
    reason: str | None = None  # Reason-Phrase

    content: StreamReader = None  # type: ignore[assignment]  # Payload stream
    _body: bytes | None = None
    _headers: CIMultiDictProxy[str] = None  # type: ignore[assignment]
    _history: tuple["ClientResponse", ...] = ()
    _raw_headers: RawHeaders = None  # type: ignore[assignment]

    _connection: "Connection | None" = None  # current connection
    _cookies: SimpleCookie | None = None
    _raw_cookie_headers: tuple[str, ...] | None = None
    _continue: asyncio.Future[bool] | None = None
    _source_traceback: traceback.StackSummary | None = None
    _session: "ClientSession | None" = None
    # set up by ClientRequest after ClientResponse object creation
    # post-init stage allows to not change ctor signature
    _closed = True  # to allow __del__ for non-initialized properly response
    _released = False
    _in_context = False

    # Fallback used when no session supplies a charset resolver.
    _resolve_charset: Callable[["ClientResponse", bytes], str] = lambda *_: "utf-8"

    __writer: asyncio.Task[None] | None = None

    def __init__(
        self,
        method: str,
        url: URL,
        *,
        writer: asyncio.Task[None] | None,
        continue100: asyncio.Future[bool] | None,
        timer: BaseTimerContext | None,
        traces: Sequence["Trace"],
        loop: asyncio.AbstractEventLoop,
        session: "ClientSession | None",
        request_headers: CIMultiDict[str],
        original_url: URL,
        **kwargs: object,
    ) -> None:
        """Store request context; no I/O happens here.

        ``url`` is kept unchanged as ``real_url`` and, with its fragment
        removed, as ``url``. ``writer`` is the task still sending the request
        body, if any. ``continue100`` is the future start() resolves when an
        informational response arrives. ``timer`` defaults to ``TimerNoop()``.
        """
        # kwargs exists so authors of subclasses should expect to pass through unknown
        # arguments. This allows us to safely add new arguments in future releases.
        # But, we should never receive unknown arguments here in the parent class, this
        # would indicate an argument has been named wrong or similar in the subclass.
        assert not kwargs, "Unexpected arguments to ClientResponse"

        # URL forbids subclasses, so a simple type check is enough.
        assert type(url) is URL

        self.method = method

        self._real_url = url
        self._url = url.with_fragment(None) if url.raw_fragment else url
        if writer is not None:
            self._writer = writer
        if continue100 is not None:
            self._continue = continue100
        self._request_headers = request_headers
        self._original_url = original_url
        self._timer = timer if timer is not None else TimerNoop()
        self._cache: dict[str, Any] = {}
        self._traces = traces
        self._loop = loop
        # Save reference to _resolve_charset, so that get_encoding() will still
        # work after the response has finished reading the body.
        if session is not None:
            # store a reference to session #1985
            self._session = session
            self._resolve_charset = session._resolve_charset
        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

    def __reset_writer(self, _: object = None) -> None:
        """Done-callback for the writer task: forget the finished task."""
        self.__writer = None

    @property
    def _writer(self) -> asyncio.Task[None] | None:
        """The writer task for streaming data.

        _writer is only provided for backwards compatibility
        for subclasses that may need to access it.
        """
        return self.__writer

    @_writer.setter
    def _writer(self, writer: asyncio.Task[None] | None) -> None:
        """Set the writer task for streaming data."""
        if self.__writer is not None:
            self.__writer.remove_done_callback(self.__reset_writer)
        self.__writer = writer
        if writer is None:
            return
        if writer.done():
            # The writer is already done, so we can clear it immediately.
            self.__writer = None
        else:
            writer.add_done_callback(self.__reset_writer)

    @property
    def cookies(self) -> SimpleCookie:
        """Response cookies, parsed lazily from the stored Set-Cookie headers."""
        if self._cookies is None:
            if self._raw_cookie_headers is not None:
                # Parse cookies for response.cookies (SimpleCookie for backward compatibility)
                cookies = SimpleCookie()
                # Use parse_set_cookie_headers for more lenient parsing that handles
                # malformed cookies better than SimpleCookie.load
                cookies.update(parse_set_cookie_headers(self._raw_cookie_headers))
                self._cookies = cookies
            else:
                self._cookies = SimpleCookie()
        return self._cookies

    @cookies.setter
    def cookies(self, cookies: SimpleCookie) -> None:
        self._cookies = cookies
        # Generate raw cookie headers from the SimpleCookie
        if cookies:
            self._raw_cookie_headers = tuple(
                morsel.OutputString() for morsel in cookies.values()
            )
        else:
            self._raw_cookie_headers = None

    @reify
    def url(self) -> URL:
        """Request URL with any fragment removed."""
        return self._url

    @reify
    def real_url(self) -> URL:
        """Request URL exactly as passed to the constructor."""
        return self._real_url

    @reify
    def host(self) -> str:
        """Host component of ``url``."""
        assert self._url.host is not None
        return self._url.host

    @reify
    def headers(self) -> "CIMultiDictProxy[str]":
        """Response headers; set by start()."""
        return self._headers

    @reify
    def raw_headers(self) -> RawHeaders:
        """Raw response headers as provided by the parsed message."""
        return self._raw_headers

    @reify
    def request_info(self) -> RequestInfo:
        """RequestInfo describing the request that produced this response."""
        # Build RequestInfo lazily from components
        headers = CIMultiDictProxy(self._request_headers)
        # tuple.__new__ bypasses RequestInfo.__new__ and fills the four fields directly.
        return tuple.__new__(
            RequestInfo, (self._url, self.method, headers, self._original_url)
        )

    @reify
    def content_disposition(self) -> ContentDisposition | None:
        """Parsed Content-Disposition header, or None when it is absent."""
        raw = self._headers.get(hdrs.CONTENT_DISPOSITION)
        if raw is None:
            return None
        disposition_type, params_dct = multipart.parse_content_disposition(raw)
        params = MappingProxyType(params_dct)
        filename = multipart.content_disposition_filename(params)
        return ContentDisposition(disposition_type, params, filename)

    def __del__(self, _warnings: Any = warnings) -> None:
        """Release a still-held connection; warn about it in loop debug mode."""
        if self._closed:
            return

        if self._connection is not None:
            self._connection.release()
            self._cleanup_writer()

            if self._loop.get_debug():
                _warnings.warn(
                    f"Unclosed response {self!r}", ResourceWarning, source=self
                )
                context = {"client_response": self, "message": "Unclosed response"}
                if self._source_traceback:
                    context["source_traceback"] = self._source_traceback
                self._loop.call_exception_handler(context)

    def __repr__(self) -> str:
        """Return ``<ClientResponse(url) [status reason]>`` followed by the headers.

        The reason phrase is made ASCII-safe with ``backslashreplace`` and is
        shown as ``None`` when it is missing or empty.
        """
        out = io.StringIO()
        ascii_encodable_url = str(self.url)
        if self.reason:
            ascii_encodable_reason = self.reason.encode(
                "ascii", "backslashreplace"
            ).decode("ascii")
        else:
            ascii_encodable_reason = "None"
        # The summary line previously printed an empty string and left the two
        # values computed above unused.
        print(
            f"<ClientResponse({ascii_encodable_url}) "
            f"[{self.status} {ascii_encodable_reason}]>",
            file=out,
        )
        print(self.headers, file=out)
        return out.getvalue()

    @property
    def connection(self) -> "Connection | None":
        """The connection currently held, or None once released or closed."""
        return self._connection

    @reify
    def history(self) -> tuple["ClientResponse", ...]:
        """A sequence of responses, if redirects occurred."""
        return self._history

    @reify
    def links(self) -> "MultiDictProxy[MultiDictProxy[str | URL]]":
        """Parsed ``Link`` headers.

        Each link is stored under its ``rel`` parameter, or under the raw
        target when ``rel`` is absent. It holds the link parameters plus a
        ``url`` entry resolved against the response URL.
        """
        links_str = ", ".join(self.headers.getall("link", []))

        if not links_str:
            return MultiDictProxy(MultiDict())

        links: MultiDict[MultiDictProxy[str | URL]] = MultiDict()

        # Split only on commas that are followed by the "<" of the next link.
        for val in re.split(r",(?=\s*<)", links_str):
            match = re.match(r"\s*<(.*)>(.*)", val)
            if match is None:  # Malformed link
                continue
            url, params_str = match.groups()
            params = params_str.split(";")[1:]

            link: MultiDict[str | URL] = MultiDict()

            for param in params:
                match = re.match(r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", param, re.M)
                if match is None:  # Malformed param
                    continue
                key, _, value, _ = match.groups()

                link.add(key, value)

            key = link.get("rel", url)

            link.add("url", self.url.join(URL(url)))

            links.add(str(key), MultiDictProxy(link))

        return MultiDictProxy(links)

    async def start(self, connection: "Connection") -> "ClientResponse":
        """Start response processing.

        Reads messages from the connection's protocol until one is not a 1xx
        informational response (101 is treated as final). Each skipped 1xx
        message resolves the pending ``continue100`` future, if any.

        Raises ClientResponseError when the protocol reports an
        HttpProcessingError.
        """
        self._closed = False
        self._protocol = connection.protocol
        self._connection = connection

        with self._timer:
            while True:
                # read response
                try:
                    protocol = self._protocol
                    message, payload = await protocol.read()  # type: ignore[union-attr]
                except HttpProcessingError as exc:
                    raise ClientResponseError(
                        self.request_info,
                        self.history,
                        status=exc.code,
                        message=exc.message,
                        headers=exc.headers,
                    ) from exc

                if message.code < 100 or message.code > 199 or message.code == 101:
                    break

                if self._continue is not None:
                    set_result(self._continue, True)
                    self._continue = None

        # payload eof handler
        payload.on_eof(self._response_eof)

        # response status
        self.version = message.version
        self.status = message.code
        self.reason = message.reason

        # headers
        self._headers = message.headers  # type is CIMultiDictProxy
        self._raw_headers = message.raw_headers  # type is Tuple[bytes, bytes]

        # payload
        self.content = payload

        # cookies
        if cookie_hdrs := self.headers.getall(hdrs.SET_COOKIE, ()):
            # Store raw cookie headers for CookieJar
            self._raw_cookie_headers = tuple(cookie_hdrs)
        return self

    def _response_eof(self) -> None:
        """Payload EOF callback: release the connection unless it was upgraded."""
        if self._closed:
            return

        # protocol could be None because connection could be detached
        protocol = self._connection and self._connection.protocol
        if protocol is not None and protocol.upgraded:
            return

        self._closed = True
        self._cleanup_writer()
        self._release_connection()

    @property
    def closed(self) -> bool:
        """True once the response was closed, released or fully consumed."""
        return self._closed

    def close(self) -> None:
        """Close the response and its underlying connection.

        Unlike release(), this calls ``close()`` on the connection.
        """
        if not self._released:
            self._notify_content()

        self._closed = True
        # With a closed loop only the state flags are updated.
        if self._loop.is_closed():
            return

        self._cleanup_writer()
        if self._connection is not None:
            self._connection.close()
            self._connection = None

    def release(self) -> None:
        """Mark the response closed and release its connection."""
        if not self._released:
            self._notify_content()

        self._closed = True

        self._cleanup_writer()
        self._release_connection()

    @property
    def ok(self) -> bool:
        """Returns ``True`` if ``status`` is less than ``400``, ``False`` if not.

        This is **not** a check for ``200 OK`` but a check that the response
        status is under 400.
        """
        return 400 > self.status

    def raise_for_status(self) -> None:
        """Raise ClientResponseError if ``status`` is 400 or higher.

        The response is released first unless it is being used as an async
        context manager.
        """
        if not self.ok:
            # reason should always be not None for a started response
            assert self.reason is not None

            # If we're in a context we can rely on __aexit__() to release as the
            # exception propagates.
            if not self._in_context:
                self.release()

            raise ClientResponseError(
                self.request_info,
                self.history,
                status=self.status,
                message=self.reason,
                headers=self.headers,
            )

    def _release_connection(self) -> None:
        """Release the connection now, or when the writer task finishes."""
        if self._connection is not None:
            if self.__writer is None:
                self._connection.release()
                self._connection = None
            else:
                # Still sending the request body: retry once the task completes.
                self.__writer.add_done_callback(lambda f: self._release_connection())

    async def _wait_released(self) -> None:
        """Wait for the writer task to finish, then release the connection."""
        if self.__writer is not None:
            try:
                await self.__writer
            except asyncio.CancelledError:
                # Re-raise only when the current task itself is being
                # cancelled (detectable on Python 3.11+).
                if (
                    sys.version_info >= (3, 11)
                    and (task := asyncio.current_task())
                    and task.cancelling()
                ):
                    raise
        self._release_connection()

    def _cleanup_writer(self) -> None:
        """Cancel the writer task, if any, and drop the session reference."""
        if self.__writer is not None:
            self.__writer.cancel()
        self._session = None

    def _notify_content(self) -> None:
        """Mark the content stream as failed with "Connection closed"."""
        content = self.content
        # content can be None here, but the types are cheated elsewhere.
        if content and content.exception() is None:  # type: ignore[truthy-bool]
            set_exception(content, _CONNECTION_CLOSED_EXCEPTION)
        self._released = True

    async def wait_for_close(self) -> None:
        """Wait for the writer task to finish, then release the response."""
        if self.__writer is not None:
            try:
                await self.__writer
            except asyncio.CancelledError:
                # Same cancellation rule as in _wait_released().
                if (
                    sys.version_info >= (3, 11)
                    and (task := asyncio.current_task())
                    and task.cancelling()
                ):
                    raise
        self.release()

    async def read(self) -> bytes:
        """Read response payload.

        The body is read once and cached. Any failure while reading closes
        the response and is re-raised. Raises ClientConnectionError when the
        body was never read and the response has already been released.
        """
        if self._body is None:
            try:
                self._body = await self.content.read()
                for trace in self._traces:
                    await trace.send_response_chunk_received(
                        self.method, self.url, self._body
                    )
            except BaseException:
                self.close()
                raise
        elif self._released:  # Response explicitly released
            raise ClientConnectionError("Connection closed")

        protocol = self._connection and self._connection.protocol
        if protocol is None or not protocol.upgraded:
            await self._wait_released()  # Underlying connection released
        return self._body

    def get_encoding(self) -> str:
        """Return the text encoding of the body.

        Resolution order: a valid ``charset`` parameter of Content-Type;
        ``utf-8`` for ``application/json`` and ``application/rdap``;
        otherwise the ``_resolve_charset`` callback applied to the body.

        Raises RuntimeError when the fallback is needed before the body has
        been read.
        """
        ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
        mimetype = parse_mimetype(ctype)

        encoding = mimetype.parameters.get("charset")
        if encoding:
            # An unknown or invalid charset falls through to the defaults below.
            with contextlib.suppress(LookupError, ValueError):
                return codecs.lookup(encoding).name

        if mimetype.type == "application" and (
            mimetype.subtype == "json" or mimetype.subtype == "rdap"
        ):
            # RFC 7159 states that the default encoding is UTF-8.
            # RFC 7483 defines application/rdap+json
            return "utf-8"

        if self._body is None:
            raise RuntimeError(
                "Cannot compute fallback encoding of a not yet read body"
            )

        return self._resolve_charset(self, self._body)

    async def text(self, encoding: str | None = None, errors: str = "strict") -> str:
        """Read response payload and decode."""
        await self.read()

        if encoding is None:
            encoding = self.get_encoding()

        return self._body.decode(encoding, errors=errors)  # type: ignore[union-attr]

    async def json(
        self,
        *,
        encoding: str | None = None,
        loads: JSONDecoder = DEFAULT_JSON_DECODER,
        content_type: str | None = "application/json",
    ) -> Any:
        """Read and decodes JSON response.

        Raises ContentTypeError when ``content_type`` is truthy and the
        response content type does not match it; pass ``None`` to skip the
        check.
        """
        await self.read()

        if content_type:
            if not is_expected_content_type(self.content_type, content_type):
                raise ContentTypeError(
                    self.request_info,
                    self.history,
                    status=self.status,
                    message=(
                        "Attempt to decode JSON with "
                        "unexpected mimetype: %s" % self.content_type
                    ),
                    headers=self.headers,
                )

        if encoding is None:
            encoding = self.get_encoding()

        return loads(self._body.decode(encoding))  # type: ignore[union-attr]

    async def __aenter__(self) -> "ClientResponse":
        self._in_context = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._in_context = False
        # similar to _RequestContextManager, we do not need to check
        # for exceptions, response object can close connection
        # if state is broken
        self.release()
        await self.wait_for_close()
class ClientRequestBase:
    """An internal class for proxy requests.

    It holds the request basics (method, URL, headers, auth) and knows how to
    write the request line and headers to a connection in _send(). This base
    class never writes a body: _write_bytes() is a placeholder.
    """

    # Methods that get a default Content-Type in _send().
    POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}

    auth = None
    proxy: URL | None = None
    response_class = ClientResponse
    server_hostname: str | None = None  # Needed in connector.py
    version = HttpVersion11
    _response = None

    # These class defaults help create_autospec() work correctly.
    # If autospec is improved in future, maybe these can be removed.
    url = URL()
    method = "GET"

    _writer_task: asyncio.Task[None] | None = None  # async task for streaming data

    _skip_auto_headers: "CIMultiDict[None] | None" = None

    # N.B.
    # Adding __del__ method with self._writer closing doesn't make sense
    # because _writer is instance method, thus it keeps a reference to self.
    # Until writer has finished finalizer will not be called.

    def __init__(
        self,
        method: str,
        url: URL,
        *,
        headers: CIMultiDict[str],
        auth: BasicAuth | None,
        loop: asyncio.AbstractEventLoop,
        ssl: SSLContext | bool | Fingerprint,
        trust_env: bool = False,
    ):
        """Validate the method and URL, then set up host, headers and auth.

        ``method`` is upper-cased. ``url`` is kept unchanged as
        ``original_url`` and, with its fragment removed, as ``url``.

        Raises ValueError when ``method`` contains a character rejected by
        ``_CONTAINS_CONTROL_CHAR_RE``.
        """
        if match := _CONTAINS_CONTROL_CHAR_RE.search(method):
            raise ValueError(
                f"Method cannot contain non-token characters {method!r} "
                f"(found at least {match.group()!r})"
            )
        # URL forbids subclasses, so a simple type check is enough.
        assert type(url) is URL, url

        self.original_url = url
        self.url = url.with_fragment(None) if url.raw_fragment else url
        self.method = method.upper()
        self.loop = loop
        self._ssl = ssl

        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # Order matters: _update_host() may set self.auth from the URL, and
        # _update_auth() writes into the headers built by _update_headers().
        self._update_host(url)
        self._update_headers(headers)
        self._update_auth(auth, trust_env)

    def _reset_writer(self, _: object = None) -> None:
        """Done-callback for the writer task: forget the finished task."""
        self._writer_task = None

    def _get_content_length(self) -> int | None:
        """Extract and validate Content-Length header value.

        Returns parsed Content-Length value or None if not set.
        Raises ValueError if header exists but cannot be parsed as an integer.
        """
        if hdrs.CONTENT_LENGTH not in self.headers:
            return None

        content_length_hdr = self.headers[hdrs.CONTENT_LENGTH]
        try:
            return int(content_length_hdr)
        except ValueError:
            raise ValueError(
                f"Invalid Content-Length header: {content_length_hdr}"
            ) from None

    @property
    def _writer(self) -> asyncio.Task[None] | None:
        """Task currently writing the request body, or None."""
        return self._writer_task

    @_writer.setter
    def _writer(self, writer: asyncio.Task[None]) -> None:
        # Move the done-callback from the previous task to the new one.
        if self._writer_task is not None:
            self._writer_task.remove_done_callback(self._reset_writer)
        self._writer_task = writer
        writer.add_done_callback(self._reset_writer)

    def is_ssl(self) -> bool:
        """Return True when the URL scheme is one of ``_SSL_SCHEMES``."""
        return self.url.scheme in _SSL_SCHEMES

    @property
    def ssl(self) -> "SSLContext | bool | Fingerprint":
        """The ``ssl`` argument passed to the constructor."""
        return self._ssl

    @property
    def connection_key(self) -> ConnectionKey:
        """ConnectionKey for this request, with all proxy fields set to None."""
        url = self.url
        # tuple.__new__ fills the named tuple's fields directly.
        return tuple.__new__(
            ConnectionKey,
            (
                url.raw_host or "",
                url.port,
                url.scheme in _SSL_SCHEMES,
                self._ssl,
                None,
                None,
                None,
            ),
        )

    def _update_auth(self, auth: BasicAuth | None, trust_env: bool = False) -> None:
        """Set basic auth.

        Falls back to ``self.auth`` when ``auth`` is None. ``trust_env`` is
        not used in this implementation.
        Raises TypeError when the credentials are not a BasicAuth instance.
        """
        if auth is None:
            auth = self.auth
        if auth is None:
            return

        if not isinstance(auth, BasicAuth):
            raise TypeError("BasicAuth() tuple is required instead")

        self.headers[hdrs.AUTHORIZATION] = auth.encode()

    def _update_host(self, url: URL) -> None:
        """Update destination host, port and connection type (ssl)."""
        # get host/port
        if not url.raw_host:
            raise InvalidURL(url)

        # basic auth info
        if url.raw_user or url.raw_password:
            self.auth = BasicAuth(url.user or "", url.password or "")

    def _update_headers(self, headers: CIMultiDict[str]) -> None:
        """Update request headers."""
        self.headers: CIMultiDict[str] = CIMultiDict()

        # Build the host header
        host = self.url.host_port_subcomponent

        # host_port_subcomponent is None when the URL is a relative URL.
        # but we know we do not have a relative URL here.
        assert host is not None
        # A caller-supplied Host header wins over the URL-derived one; note
        # that pop() removes it from the caller's mapping.
        self.headers[hdrs.HOST] = headers.pop(hdrs.HOST, host)
        self.headers.extend(headers)

    def _create_response(self, task: asyncio.Task[None] | None) -> ClientResponse:
        """Instantiate ``response_class`` for this request.

        ``task`` is the writer task still sending the body, or None.
        """
        return self.response_class(
            self.method,
            self.original_url,
            writer=task,
            continue100=None,
            timer=TimerNoop(),
            traces=(),
            loop=self.loop,
            session=None,
            request_headers=self.headers,
            original_url=self.original_url,
        )

    def _create_writer(self, protocol: BaseProtocol) -> StreamWriter:
        """Create the StreamWriter used to send this request."""
        return StreamWriter(protocol, self.loop)

    def _should_write(self, protocol: BaseProtocol) -> bool:
        """Return whether _send() must start a writer task.

        Here that is the protocol's ``writing_paused`` flag.
        """
        return protocol.writing_paused

    async def _send(self, conn: "Connection") -> ClientResponse:
        """Write the request line and headers and return the response object.

        When _should_write() is true, a task running _write_bytes() is
        started. Otherwise the read timeout is started and the writer is
        marked as finished.
        """
        # Specify request target:
        # - CONNECT request must send authority form URI
        # - not CONNECT proxy must send absolute form URI
        # - most common is origin form URI
        if self.method == hdrs.METH_CONNECT:
            connect_host = self.url.host_subcomponent
            assert connect_host is not None
            path = f"{connect_host}:{self.url.port}"
        elif self.proxy and not self.is_ssl():
            path = str(self.url)
        else:
            path = self.url.raw_path_qs

        protocol = conn.protocol
        assert protocol is not None
        writer = self._create_writer(protocol)

        # set default content-type
        if (
            self.method in self.POST_METHODS
            and (
                self._skip_auto_headers is None
                or hdrs.CONTENT_TYPE not in self._skip_auto_headers
            )
            and hdrs.CONTENT_TYPE not in self.headers
        ):
            self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"

        v = self.version
        # Add a Connection header only where the connector's force_close
        # setting needs one: "close" on HTTP/1.1 when force_close is on,
        # "keep-alive" on HTTP/1.0 when it is off.
        if hdrs.CONNECTION not in self.headers:
            if conn._connector.force_close:
                if v == HttpVersion11:
                    self.headers[hdrs.CONNECTION] = "close"
            elif v == HttpVersion10:
                self.headers[hdrs.CONNECTION] = "keep-alive"

        # status + headers
        status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}"

        # Buffer headers for potential coalescing with body
        await writer.write_headers(status_line, self.headers)

        task: asyncio.Task[None] | None
        if self._should_write(protocol):
            coro = self._write_bytes(writer, conn, self._get_content_length())
            if sys.version_info >= (3, 12):
                # Optimization for Python 3.12, try to write
                # bytes immediately to avoid having to schedule
                # the task on the event loop.
                task = asyncio.Task(coro, loop=self.loop, eager_start=True)
            else:
                task = self.loop.create_task(coro)
            # A task that has already finished is not tracked.
            if task.done():
                task = None
            else:
                self._writer = task
        else:
            # We have nothing to write because
            # - there is no body
            # - the protocol does not have writing paused
            # - we are not waiting for a 100-continue response
            protocol.start_timeout()
            writer.set_eof()
            task = None
        self._response = self._create_response(task)
        return self._response

    async def _write_bytes(
        self,
        writer: AbstractStreamWriter,
        conn: "Connection",
        content_length: int | None,
    ) -> None:
        """Placeholder for writing the request body."""
        # Base class never has a body, this will never be run.
        assert False
@property
def connection_key(self) -> ConnectionKey:
    """Return the key that identifies which connections can serve this request.

    The key combines the target host and port, whether the scheme is an SSL
    scheme, the SSL setting, the proxy and its auth, and a hash of the proxy
    headers (``None`` when there are no proxy headers).
    """
    proxy_hdrs = self.proxy_headers
    headers_hash: int | None = (
        hash(tuple(proxy_hdrs.items())) if proxy_hdrs else None
    )
    target = self.url
    fields = (
        target.raw_host or "",
        target.port,
        target.scheme in _SSL_SCHEMES,
        self._ssl,
        self.proxy,
        self.proxy_auth,
        headers_hash,
    )
    # Build the tuple subclass straight from the prepared field tuple,
    # bypassing ConnectionKey's own constructor.
    return tuple.__new__(ConnectionKey, fields)
def _update_content_encoding(self, data: Any, compress: bool | str) -> None:
    """Decide whether the body is compressed and set ``Content-Encoding``.

    ``self.compress`` is always reset to ``None`` first. Nothing else happens
    for an empty body. When the caller already supplied a Content-Encoding
    header, asking for compression as well is an error; otherwise a truthy
    ``compress`` selects the encoding (its own value when it is a string,
    ``"deflate"`` when it is just ``True``) and switches the request to
    chunked transfer.

    Raises:
        ValueError: if ``compress`` is set together with a Content-Encoding
            header.
    """
    self.compress = None
    if not data:
        return

    header_present = bool(self.headers.get(hdrs.CONTENT_ENCODING))
    if header_present and compress:
        raise ValueError(
            "compress can not be set if Content-Encoding header is set"
        )
    if header_present or not compress:
        # Either the caller manages the encoding, or none was requested.
        return

    if isinstance(compress, str):
        encoding = compress
    else:
        encoding = "deflate"
    self.compress = encoding
    self.headers[hdrs.CONTENT_ENCODING] = encoding
    # enable chunked, no need to deal with length
    self.chunked = True
def _update_body_from_data(self, body: Any) -> None:
    """Install ``body`` as the request payload and derive framing headers.

    ``None`` selects the shared empty payload (and, for methods outside
    ``GET_METHODS``, an explicit ``Content-Length: 0`` unless chunking is on
    or a length is already present). Any other value is turned into a payload
    object: ``FormData`` is called directly, everything else goes through the
    payload registry, falling back to wrapping it in ``FormData``.
    """
    if body is None:
        self._body = self._EMPTY_BODY
        # Set Content-Length to 0 when body is None for methods that expect a body
        needs_zero_length = (
            self.method not in self.GET_METHODS
            and not self.chunked
            and hdrs.CONTENT_LENGTH not in self.headers
        )
        if needs_zero_length:
            self.headers[hdrs.CONTENT_LENGTH] = "0"
        return

    if isinstance(body, FormData):
        payload_obj = body()
    else:
        try:
            payload_obj = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
        except payload.LookupError:
            # No registered payload type: treat the value as form fields and
            # reuse a boundary the caller may have put in Content-Type.
            boundary = (
                parse_mimetype(self.headers[hdrs.CONTENT_TYPE]).parameters.get(
                    "boundary"
                )
                if hdrs.CONTENT_TYPE in self.headers
                else None
            )
            payload_obj = FormData(body, boundary=boundary)()

    self._body = payload_obj
    request_headers = self.headers

    # Known size -> Content-Length; unknown size -> chunked transfer.
    if not self.chunked and hdrs.CONTENT_LENGTH not in request_headers:
        size = payload_obj.size
        if size is None:
            self.chunked = True
        else:
            request_headers[hdrs.CONTENT_LENGTH] = str(size)

    # Copy payload headers that the request neither sets nor suppresses.
    assert payload_obj.headers
    suppressed = self._skip_auto_headers
    for name, value in payload_obj.headers.items():
        if name in request_headers:
            continue
        if suppressed is not None and name in suppressed:
            continue
        request_headers[name] = value
async def update_body(self, body: Any) -> None:
    """
    Update request body and close previous payload if needed.

    This method safely updates the request body by first closing any existing
    payload to prevent resource leaks, then setting the new body.

    IMPORTANT: Always use this method instead of setting request.body directly.
    Direct assignment to request.body will leak resources if the previous body
    contains file handles, streams, or other resources that need cleanup.

    Args:
        body: The new body content. Can be:
            - bytes/bytearray: Raw binary data
            - str: Text data (will be encoded using charset from Content-Type)
            - FormData: Form data that will be encoded as multipart/form-data
            - Payload: A pre-configured payload object
            - AsyncIterable: An async iterable of bytes chunks
            - File-like object: Will be read and sent as binary data
            - None: Clears the body

    Usage:
        # CORRECT: Use update_body
        await request.update_body(b"new request data")

        # WRONG: Don't set body directly
        # request.body = b"new request data"  # This will leak resources!

        # Update with form data
        form_data = FormData()
        form_data.add_field('field', 'value')
        await request.update_body(form_data)

        # Clear body
        await request.update_body(None)

    Note:
        This method is async because it may need to close file handles or
        other resources associated with the previous payload. Always await
        this method to ensure proper cleanup.

    Warning:
        Setting request.body directly is highly discouraged and can lead to:
        - Resource leaks (unclosed file handles, streams)
        - Memory leaks (unreleased buffers)
        - Unexpected behavior with streaming payloads

        It is not recommended to change the payload type in middleware. If the
        body was already set (e.g., as bytes), it's best to keep the same type
        rather than converting it (e.g., to str) as this may result in
        unexpected behavior.

    See Also:
        - _update_body: Synchronous header reset and body update, no cleanup
        - _update_body_from_data: Synchronous body update without cleanup
        - body property: Direct body access (STRONGLY DISCOURAGED)

    """
    # Close existing payload if it exists and needs closing
    if self._body is not None:
        await self._body.close()
    # Drops the stale Content-Length/Transfer-Encoding headers and installs
    # the new payload.
    self._update_body(body)
None, ) -> None: """ Write the request body to the connection stream. This method handles writing different types of request bodies: 1. Payload objects (using their specialized write_with_length method) 2. Bytes/bytearray objects 3. Iterable body content Args: writer: The stream writer to write the body to conn: The connection being used for this request content_length: Optional maximum number of bytes to write from the body (None means write the entire body) The method properly handles: - Waiting for 100-Continue responses if required - Content length constraints for chunked encoding - Error handling for network issues, cancellation, and other exceptions - Signaling EOF and timeout management Raises: ClientOSError: When there's an OS-level error writing the body ClientConnectionError: When there's a general connection error asyncio.CancelledError: When the operation is cancelled """ # 100 response if self._continue is not None: # Force headers to be sent before waiting for 100-continue writer.send_headers() await writer.drain() await self._continue protocol = conn.protocol assert protocol is not None try: await self._body.write_with_length(writer, content_length) except OSError as underlying_exc: reraised_exc = underlying_exc # Distinguish between timeout and other OS errors for better error reporting exc_is_not_timeout = underlying_exc.errno is not None or not isinstance( underlying_exc, asyncio.TimeoutError ) if exc_is_not_timeout: reraised_exc = ClientOSError( underlying_exc.errno, f"Can not write request body for {self.url !s}", ) set_exception(protocol, reraised_exc, underlying_exc) except asyncio.CancelledError: # Body hasn't been fully sent, so connection can't be reused conn.close() raise except Exception as underlying_exc: set_exception( protocol, ClientConnectionError( "Failed to send bytes into the underlying connection " f"{conn !s}: {underlying_exc!r}", ), underlying_exc, ) else: # Successfully wrote the body, signal EOF and start response timeout 
async def _close(self) -> None:
    """Wait for the pending body-writer task, if there is one.

    A ``CancelledError`` coming out of the writer task is swallowed, except
    on Python 3.11+ when the task awaiting this coroutine is itself being
    cancelled; then the cancellation is propagated.
    """
    writer_task = self._writer_task
    if writer_task is None:
        return
    try:
        await writer_task
    except asyncio.CancelledError:
        if sys.version_info < (3, 11):
            return
        current = asyncio.current_task()
        if current and current.cancelling():
            # The caller is being cancelled, not just the writer.
            raise
def __init__(
    self,
    reader: WebSocketDataQueue,
    writer: WebSocketWriter,
    protocol: str | None,
    response: ClientResponse,
    timeout: ClientWSTimeout,
    autoclose: bool,
    autoping: bool,
    loop: asyncio.AbstractEventLoop,
    *,
    heartbeat: float | None = None,
    compress: int = 0,
    client_notakeover: bool = False,
) -> None:
    """Store the websocket collaborators and arm the heartbeat if enabled.

    When ``heartbeat`` is given, the time allowed for a PONG is half of it
    and the first heartbeat is scheduled before the constructor returns.
    """
    # Transport-level collaborators.
    self._response = response
    self._conn = response.connection
    self._writer = writer
    self._reader = reader
    self._loop = loop

    # Negotiated parameters and user options.
    self._protocol = protocol
    self._timeout = timeout
    self._autoclose = autoclose
    self._autoping = autoping
    self._compress = compress
    self._client_notakeover = client_notakeover

    # Close / receive state.
    self._closed = False
    self._closing = False
    self._close_code: int | None = None
    self._waiting: bool = False
    self._close_wait: asyncio.Future[None] | None = None
    self._exception: BaseException | None = None

    # Heartbeat bookkeeping.
    self._heartbeat = heartbeat
    self._heartbeat_cb: asyncio.TimerHandle | None = None
    self._heartbeat_when: float = 0.0
    if heartbeat is not None:
        # Only defined when heartbeats are enabled.
        self._pong_heartbeat = heartbeat / 2.0
    self._pong_response_cb: asyncio.TimerHandle | None = None
    self._ping_task: asyncio.Task[None] | None = None
    self._need_heartbeat_reset = False
    self._heartbeat_reset_handle: asyncio.Handle | None = None

    self._reset_heartbeat()
def _on_data_received(self) -> None:
    """Request a heartbeat reset because data just arrived.

    Does nothing when heartbeats are disabled or a reset is already pending.
    Otherwise the reset is deferred with ``call_soon`` so that several chunks
    received in the same loop iteration trigger only one reset.
    """
    heartbeat_disabled = self._heartbeat is None
    if heartbeat_disabled or self._need_heartbeat_reset:
        return
    event_loop = self._loop
    assert event_loop is not None
    # Mark the reset as pending before scheduling the flush; resetting per
    # chunk would increase timer churn.
    self._need_heartbeat_reset = True
    self._heartbeat_reset_handle = event_loop.call_soon(
        self._flush_heartbeat_reset
    )
def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
    """Handle completion of the ping task.

    A failure of a task that was not cancelled is forwarded to
    ``_handle_ping_pong_exception``; in every case the stored ping task
    reference is cleared.
    """
    # exception() must not be consulted on a cancelled task.
    failure = None if task.cancelled() else task.exception()
    if failure:
        self._handle_ping_pong_exception(failure)
    self._ping_task = None
def get_extra_info(self, name: str, default: Any = None) -> Any:
    """Return transport extra info ``name`` for the response's connection.

    ``default`` is returned when the response has no connection, or the
    connection has no transport.
    """
    connection = self._response.connection
    transport = None if connection is None else connection.transport
    if transport is None:
        return default
    return transport.get_extra_info(name, default)
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
    """Close the websocket and wait for the peer's CLOSE message.

    Args:
        code: Close code passed to the writer's ``close``.
        message: Close payload passed to the writer's ``close``.

    Returns:
        ``False`` if the websocket was already closed, ``True`` once this
        call has closed it (including when closing failed with an error,
        which is then available from ``exception()``).

    Raises:
        asyncio.CancelledError: re-raised after recording an abnormal
            closure and closing the response.
    """
    # we need to break `receive()` cycle first,
    # `close()` may be called from different task
    if self._waiting and not self._closing:
        assert self._loop is not None
        self._close_wait = self._loop.create_future()
        self._set_closing()
        # Wake the pending reader with a CLOSING message, then wait until
        # `_close_wait` is resolved.
        self._reader.feed_data(WS_CLOSING_MESSAGE)
        await self._close_wait

    if self._closed:
        return False

    self._set_closed()
    try:
        await self._writer.close(code, message)
    except asyncio.CancelledError:
        self._close_code = WSCloseCode.ABNORMAL_CLOSURE
        self._response.close()
        raise
    except Exception as exc:
        # Failure while closing the writer: remember it and still report
        # the websocket as closed.
        self._close_code = WSCloseCode.ABNORMAL_CLOSURE
        self._exception = exc
        self._response.close()
        return True

    # A close code is already recorded, so there is nothing left to wait for.
    if self._close_code:
        self._response.close()
        return True

    # Keep reading until a CLOSE message arrives; each read is bounded by
    # the `ws_close` timeout.
    while True:
        try:
            async with async_timeout.timeout(self._timeout.ws_close):
                msg = await self._reader.read()
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            self._response.close()
            return True

        if msg.type is WSMsgType.CLOSE:
            self._close_code = msg.data
            self._response.close()
            return True
async def receive(
    self, timeout: float | None = None
) -> WSMessageDecodeText | WSMessageNoDecodeText:
    """Return the next message, handling control frames along the way.

    Args:
        timeout: Read timeout for this call; a falsy value (``None`` or
            ``0``) falls back to the configured ``ws_receive`` timeout.

    Returns:
        The received message, ``WS_CLOSED_MESSAGE`` when the websocket is
        closed or gets closed while reading, or a ``WSMessageError``
        wrapping the exception when reading fails.

    Raises:
        RuntimeError: if another ``receive()`` call is already waiting.
        asyncio.CancelledError, asyncio.TimeoutError: re-raised after the
            close code is set to ``ABNORMAL_CLOSURE``.
    """
    # `or` means an explicit timeout of 0 also selects the default.
    receive_timeout = timeout or self._timeout.ws_receive

    while True:
        if self._waiting:
            raise RuntimeError("Concurrent call to receive() is not allowed")

        if self._closed:
            return WS_CLOSED_MESSAGE
        elif self._closing:
            await self.close()
            return WS_CLOSED_MESSAGE

        try:
            self._waiting = True
            try:
                if receive_timeout:
                    # Entering the context manager and creating
                    # Timeout() object can take almost 50% of the
                    # run time in this loop so we avoid it if
                    # there is no read timeout.
                    async with async_timeout.timeout(receive_timeout):
                        msg = await self._reader.read()
                else:
                    msg = await self._reader.read()
            finally:
                self._waiting = False
                # Release a close() call that is waiting for this read to
                # finish.
                if self._close_wait:
                    set_result(self._close_wait, None)
        except (asyncio.CancelledError, asyncio.TimeoutError):
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            raise
        except EofStream:
            self._close_code = WSCloseCode.OK
            await self.close()
            return WS_CLOSED_MESSAGE
        except ClientError:
            # Likely ServerDisconnectedError when connection is lost
            self._set_closed()
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            return WS_CLOSED_MESSAGE
        except WebSocketError as exc:
            self._close_code = exc.code
            await self.close(code=exc.code)
            return WSMessageError(data=exc)
        except Exception as exc:
            self._exception = exc
            self._set_closing()
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            await self.close()
            return WSMessageError(data=exc)

        if msg.type not in _INTERNAL_RECEIVE_TYPES:
            # If it's not a close/closing/ping/pong message
            # we can return it immediately
            return msg

        if msg.type is WSMsgType.CLOSE:
            self._set_closing()
            self._close_code = msg.data
            # Could be closed elsewhere while awaiting reader
            if not self._closed and self._autoclose:  # type: ignore[redundant-expr]
                await self.close()
        elif msg.type is WSMsgType.CLOSING:
            self._set_closing()
        elif msg.type is WSMsgType.PING and self._autoping:
            # Answer the ping and keep waiting for the next message.
            await self.pong(msg.data)
            continue
        elif msg.type is WSMsgType.PONG and self._autoping:
            continue

        # CLOSE/CLOSING, and PING/PONG when autoping is off, are returned to
        # the caller.
        return msg
async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText:
    """Return the next message for ``async for``.

    Iteration stops as soon as a CLOSE, CLOSING or CLOSED message is
    received.
    """
    message = await self.receive()
    terminal_types = (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED)
    if message.type not in terminal_types:
        return message
    raise StopAsyncIteration
class ZLibBackendProtocol(Protocol):
    """Structural interface expected from a zlib-compatible backend.

    ``set_zlib_backend()`` accepts any object matching this protocol and
    ``ZLibBackendWrapper`` forwards to it; the standard library ``zlib``
    module is the backend the module-level wrapper is created with.
    """

    # Constants the wrapper exposes as read-only properties.
    MAX_WBITS: int
    Z_FULL_FLUSH: int
    Z_SYNC_FLUSH: int
    Z_BEST_SPEED: int
    Z_FINISH: int

    # Factory for a streaming compression object.
    def compressobj(
        self,
        level: int = ...,
        method: int = ...,
        wbits: int = ...,
        memLevel: int = ...,
        strategy: int = ...,
        zdict: Buffer | None = ...,
    ) -> ZLibCompressObjProtocol: ...

    # Factory for a streaming decompression object.
    def decompressobj(
        self, wbits: int = ..., zdict: Buffer = ...
    ) -> ZLibDecompressObjProtocol: ...

    # One-shot compression; ``data`` is positional-only.
    def compress(
        self, data: Buffer, /, level: int = ..., wbits: int = ...
    ) -> bytes: ...

    # One-shot decompression; ``data`` is positional-only.
    def decompress(
        self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
    ) -> bytes: ...
class DecompressionBaseHandler(ABC):
    """Common base class for the decompressors in this module.

    Subclasses implement ``decompress_sync()``. The async ``decompress()``
    wrapper decides, per chunk, whether to run it inline or to hand it to
    an executor so that large chunks do not block the event loop.
    """

    def __init__(
        self,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ):
        """Store the executor and the inline-size threshold.

        :param executor: executor passed to ``run_in_executor()`` for
            chunks above the threshold (``None`` is passed through as-is).
        :param max_sync_chunk_size: largest chunk, in bytes, that is
            decompressed inline; ``None`` disables offloading entirely.
        """
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size

    @abstractmethod
    def decompress_sync(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data."""

    async def decompress(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data.

        Chunks larger than ``max_sync_chunk_size`` are decompressed via
        ``decompress_sync()`` in the configured executor; everything else
        is decompressed inline.
        """
        if (
            self._max_sync_chunk_size is not None
            and len(data) > self._max_sync_chunk_size
        ):
            # We are inside a coroutine, so a loop is always running here.
            # get_running_loop() matches ZLibCompressor.compress() and avoids
            # the policy-dependent behaviour of get_event_loop().
            return await asyncio.get_running_loop().run_in_executor(
                self._executor, self.decompress_sync, data, max_length
            )
        return self.decompress_sync(data, max_length)
kwargs["strategy"] = strategy if level is not None: kwargs["level"] = level self._compressor = self._zlib_backend.compressobj(**kwargs) def compress_sync(self, data: Buffer) -> bytes: return self._compressor.compress(data) async def compress(self, data: Buffer) -> bytes: """Compress the data and returned the compressed bytes. Note that flush() must be called after the last call to compress() If the data size is large than the max_sync_chunk_size, the compression will be done in the executor. Otherwise, the compression will be done in the event loop. **WARNING: This method is NOT cancellation-safe when used with flush().** If this operation is cancelled, the compressor state may be corrupted. The connection MUST be closed after cancellation to avoid data corruption in subsequent compress operations. For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap compress() + flush() + send operations in a shield and lock to ensure atomicity. """ # For large payloads, offload compression to executor to avoid blocking event loop should_use_executor = ( self._max_sync_chunk_size is not None and len(data) > self._max_sync_chunk_size ) if should_use_executor: return await asyncio.get_running_loop().run_in_executor( self._executor, self._compressor.compress, data ) return self.compress_sync(data) def flush(self, mode: int | None = None) -> bytes: """Flush the compressor synchronously. **WARNING: This method is NOT cancellation-safe when called after compress().** The flush() operation accesses shared compressor state. If compress() was cancelled, calling flush() may result in corrupted data. The connection MUST be closed after compress() cancellation. For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap compress() + flush() + send operations in a shield and lock to ensure atomicity. 
class ZLibDecompressor(DecompressionBaseHandler):
    """Streaming decompressor built on the currently configured zlib backend."""

    def __init__(
        self,
        encoding: str | None = None,
        suppress_deflate_header: bool = False,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ):
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
        wbits = encoding_to_mode(encoding, suppress_deflate_header)
        # Wrap whatever backend is active right now, so this instance keeps
        # using it for its whole lifetime.
        backend = ZLibBackendWrapper(ZLibBackend._zlib_backend)
        self._mode = wbits
        self._zlib_backend: Final = backend
        self._decompressor = backend.decompressobj(wbits=wbits)

    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress one chunk, forwarding ``max_length`` to the backend."""
        return self._decompressor.decompress(data, max_length)

    def flush(self, length: int = 0) -> bytes:
        """Flush the decompressor; ``length`` is forwarded only when positive."""
        if length > 0:
            return self._decompressor.flush(length)
        return self._decompressor.flush()

    @property
    def eof(self) -> bool:
        """The ``eof`` flag reported by the underlying decompression object."""
        return self._decompressor.eof
class ZSTDDecompressor(DecompressionBaseHandler):
    """Streaming decompressor backed by ``ZstdDecompressor``."""

    def __init__(
        self,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        """Create the decompressor; raises RuntimeError when zstd is missing."""
        if not HAS_ZSTD:
            raise RuntimeError(
                "The zstd decompression is not available. "
                "Please install `backports.zstd` module"
            )
        self._obj = ZstdDecompressor()
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)

    def decompress_sync(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress one chunk, translating the "unlimited" marker.

        Callers use the zlib convention (0 means unlimited) while zstd
        expects -1 for unlimited, so that single value is mapped across;
        any other ``max_length`` is forwarded unchanged.
        """
        if max_length == ZLIB_MAX_LENGTH_UNLIMITED:
            return self._obj.decompress(data, ZSTD_MAX_LENGTH_UNLIMITED)
        return self._obj.decompress(data, max_length)

    def flush(self) -> bytes:
        """Always returns an empty bytes object."""
        return b""
class Connection:
    """Represents a single connection.

    Holds a protocol on behalf of a connector and gives it back through
    ``release()`` (eligible for reuse) or ``close()`` (forced close).
    Callbacks registered with ``add_callback()`` run once when either
    happens.
    """

    __slots__ = (
        "_key",
        "_connector",
        "_loop",
        "_protocol",
        "_callbacks",
        "_source_traceback",
    )

    def __init__(
        self,
        connector: "BaseConnector",
        key: "ConnectionKey",
        protocol: ResponseHandler,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._connector = connector
        self._key = key
        self._loop = loop
        self._protocol: ResponseHandler | None = protocol
        self._callbacks: list[Callable[[], None]] = []
        # The creator's stack is only captured when the loop is in debug mode.
        self._source_traceback = (
            traceback.extract_stack(sys._getframe(1)) if loop.get_debug() else None
        )

    def __repr__(self) -> str:
        return f"Connection<{self._key}>"

    def __del__(self, _warnings: Any = warnings) -> None:
        # Nothing to report when the protocol was already handed back.
        if self._protocol is None:
            return

        _warnings.warn(f"Unclosed connection {self!r}", ResourceWarning, source=self)
        if self._loop.is_closed():
            return

        self._connector._release(self._key, self._protocol, should_close=True)

        context = {"client_connection": self, "message": "Unclosed connection"}
        if self._source_traceback is not None:
            context["source_traceback"] = self._source_traceback
        self._loop.call_exception_handler(context)

    def __bool__(self) -> Literal[True]:
        """Force subclasses to not be falsy, to make checks simpler."""
        return True

    @property
    def transport(self) -> asyncio.Transport | None:
        """Transport of the held protocol, or None once it was handed back."""
        return None if self._protocol is None else self._protocol.transport

    @property
    def protocol(self) -> ResponseHandler | None:
        """The held protocol, or None once it was handed back."""
        return self._protocol

    def add_callback(self, callback: Callable[[], None]) -> None:
        """Register a callback to run on release/close; None is ignored."""
        if callback is None:
            return
        self._callbacks.append(callback)

    def _notify_release(self) -> None:
        """Run and forget all registered callbacks.

        The list is swapped out before the callbacks run, and an exception
        raised by one callback does not stop the remaining ones.
        """
        pending = self._callbacks[:]
        self._callbacks = []
        for callback in pending:
            with suppress(Exception):
                callback()

    def close(self) -> None:
        """Hand the protocol back to the connector, asking it to close it."""
        self._notify_release()

        if self._protocol is None:
            return
        self._connector._release(self._key, self._protocol, should_close=True)
        self._protocol = None

    def release(self) -> None:
        """Hand the protocol back to the connector without forcing a close."""
        self._notify_release()

        if self._protocol is None:
            return
        self._connector._release(self._key, self._protocol)
        self._protocol = None

    @property
    def closed(self) -> bool:
        """True when no protocol is held or the held one is not connected."""
        protocol = self._protocol
        if protocol is None:
            return True
        return not protocol.is_connected()
wrapper for CONNECT tunnels that must never be pooled. This connection wraps the proxy connection that will be upgraded with TLS. It must never be released to the pool because: 1. Its 'closed' future will never complete, causing session.close() to hang 2. It represents an intermediate state, not a reusable connection 3. The real connection (with TLS) will be created separately """ def release(self) -> None: """Do nothing - don't pool or close the connection. These connections are an intermediate state during the CONNECT tunnel setup and will be cleaned up naturally after the TLS upgrade. If they were to be pooled, they would never be properly closed, causing session.close() to wait forever for their 'closed' future. """ class _TransportPlaceholder: """placeholder for BaseConnector.connect function""" __slots__ = ("closed", "transport") def __init__(self, closed_future: asyncio.Future[Exception | None]) -> None: """Initialize a placeholder for a transport.""" self.closed = closed_future self.transport = None def close(self) -> None: """Close the placeholder.""" def abort(self) -> None: """Abort the placeholder (does nothing).""" class BaseConnector: """Base connector class. keepalive_timeout - (optional) Keep-alive timeout. force_close - Set to True to force close and do reconnect after each request (and between redirects). limit - The total number of simultaneous connections. limit_per_host - Number of simultaneous connections to one host. enable_cleanup_closed - Enables clean-up closed ssl transports. Disabled by default. timeout_ceil_threshold - Trigger ceiling of timeout values when it's above timeout_ceil_threshold. loop - Optional event loop. 
""" _closed = True # prevent AttributeError in __del__ if ctor was failed _source_traceback = None # abort transport after 2 seconds (cleanup broken connections) _cleanup_closed_period = 2.0 allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET def __init__( self, *, keepalive_timeout: _SENTINEL | None | float = sentinel, force_close: bool = False, limit: int = 100, limit_per_host: int = 0, enable_cleanup_closed: bool = False, timeout_ceil_threshold: float = 5, ) -> None: if force_close: if keepalive_timeout is not None and keepalive_timeout is not sentinel: raise ValueError( "keepalive_timeout cannot be set if force_close is True" ) else: if keepalive_timeout is sentinel: keepalive_timeout = 15.0 self._timeout_ceil_threshold = timeout_ceil_threshold loop = asyncio.get_running_loop() self._closed = False if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) # Connection pool of reusable connections. # We use a deque to store connections because it has O(1) popleft() # and O(1) append() operations to implement a FIFO queue. self._conns: defaultdict[ ConnectionKey, deque[tuple[ResponseHandler, float]] ] = defaultdict(deque) self._limit = limit self._limit_per_host = limit_per_host self._acquired: set[ResponseHandler] = set() self._acquired_per_host: defaultdict[ConnectionKey, set[ResponseHandler]] = ( defaultdict(set) ) self._keepalive_timeout = cast(float, keepalive_timeout) self._force_close = force_close # {host_key: FIFO list of waiters} # The FIFO is implemented with an OrderedDict with None keys because # python does not have an ordered set. 
def __del__(self, _warnings: Any = warnings) -> None:
    """Warn about, and clean up, a connector that was never closed.

    Does nothing when the connector is already closed or has no pooled
    connections. Otherwise the pool is closed immediately, a
    ResourceWarning is emitted and the loop's exception handler is told.
    """
    if self._closed or not self._conns:
        return

    # Snapshot the reprs first: closing empties the pool.
    pooled_reprs = list(map(repr, self._conns.values()))

    self._close_immediately()

    _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, source=self)

    context = {
        "connector": self,
        "connections": pooled_reprs,
        "message": "Unclosed connector",
    }
    traceback_info = self._source_traceback
    if traceback_info is not None:
        context["source_traceback"] = traceback_info
    self._loop.call_exception_handler(context)
def _cleanup(self) -> None:
    """Drop pooled connections that are disconnected or past keep-alive.

    The pool is rebuilt with only the surviving connections. When
    anything survives, the next run is scheduled after the keep-alive
    timeout.
    """
    pending_handle = self._cleanup_handle
    if pending_handle:
        pending_handle.cancel()
        # Must be reset here, otherwise _release() would never schedule
        # another cleanup.
        self._cleanup_handle = None

    now = monotonic()
    timeout = self._keepalive_timeout

    if self._conns:
        deadline = now - timeout
        survivors = defaultdict(deque)
        for key, pooled in self._conns.items():
            kept: deque[tuple[ResponseHandler, float]] = deque()
            for proto, use_time in pooled:
                if proto.is_connected() and use_time - deadline >= 0:
                    kept.append((proto, use_time))
                else:
                    # Grab the transport before closing so it can still be
                    # queued for the closed-transport cleanup.
                    transport = proto.transport
                    proto.close()
                    if not self._cleanup_closed_disabled and key.is_ssl:
                        self._cleanup_closed_transports.append(transport)
            if kept:
                survivors[key] = kept
        self._conns = survivors

    if self._conns:
        self._cleanup_handle = helpers.weakref_handle(
            self,
            "_cleanup",
            timeout,
            self._loop,
            timeout_ceil_threshold=self._timeout_ceil_threshold,
        )
def _close_immediately(self, *, abort_ssl: bool = False) -> list[Awaitable[object]]:
    """Close every pooled and acquired protocol without awaiting anything.

    Returns the truthy ``closed`` objects of the protocols that were shut
    down so the caller can await them. An empty list is returned when the
    connector is already closed or its loop is closed. Internal state
    (pool, acquired set, waiters, timer handles) is always reset once the
    connector has been marked closed.

    :param abort_ssl: abort, rather than close, protocols whose transport
        reports an ``sslcontext``.
    """
    pending: list[Awaitable[object]] = []
    if self._closed:
        return pending
    self._closed = True

    try:
        if self._loop.is_closed():
            return pending

        # cancel cleanup task and cleanup close task
        for handle in (self._cleanup_handle, self._cleanup_closed_handle):
            if handle:
                handle.cancel()

        # Pooled protocols first, then the ones currently acquired.
        pooled = (proto for bucket in self._conns.values() for proto, _ in bucket)
        for proto in chain(pooled, self._acquired):
            if (
                abort_ssl
                and proto.transport
                and proto.transport.get_extra_info("sslcontext") is not None
            ):
                proto.abort()
            else:
                proto.close()
            closed = proto.closed
            if closed:
                pending.append(closed)

        # TODO (A.Yushovskiy, 24-May-2019) collect transp. closing futures
        for transport in self._cleanup_closed_transports:
            if transport is not None:
                transport.abort()

        return pending

    finally:
        self._conns.clear()
        self._acquired.clear()
        # Wake up everyone still queued for a connection slot.
        for waiter in chain.from_iterable(self._waiters.values()):
            waiter.cancel()
        self._waiters.clear()
        self._cleanup_handle = None
        self._cleanup_closed_transports.clear()
        self._cleanup_closed_handle = None
async def connect(
    self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout"
) -> Connection:
    """Get from pool or create new connection.

    A pooled connection for ``req.connection_key`` is returned straight
    away when ``_get()`` finds one. Otherwise, under the
    ``timeout.connect`` ceiling, the call waits for a free slot if the
    limits are exhausted and then creates a new connection through
    ``_create_connection()``.

    Raises ClientConnectionError when the connector turns out to be
    closed once the new connection has been created.
    """
    key = req.connection_key
    if (conn := await self._get(key, traces)) is not None:
        # If we do not have to wait and we can get a connection from the pool
        # we can avoid the timeout ceil logic and directly return the connection
        if req.proxy:
            self._update_proxy_auth_header_and_build_proxy_req(req)
        return conn

    async with ceil_timeout(timeout.connect, timeout.ceil_threshold):
        if self._available_connections(key) <= 0:
            await self._wait_for_available_connection(key, traces)
            # A connection may have been returned to the pool while we were
            # queued, so check the pool again before creating a new one.
            if (conn := await self._get(key, traces)) is not None:
                if req.proxy:
                    self._update_proxy_auth_header_and_build_proxy_req(req)
                return conn

        # Reserve the slot with a placeholder so the limits already count
        # this connection while it is still being created.
        placeholder = cast(
            ResponseHandler, _TransportPlaceholder(self._placeholder_future)
        )
        self._acquired.add(placeholder)
        if self._limit_per_host:
            self._acquired_per_host[key].add(placeholder)

        try:
            # Traces are done inside the try block to ensure that
            # the placeholder is still cleaned up if an exception
            # is raised.
            if traces:
                for trace in traces:
                    await trace.send_connection_create_start()
            proto = await self._create_connection(req, traces, timeout)
            if traces:
                for trace in traces:
                    await trace.send_connection_create_end()
        except BaseException:
            # BaseException so that cancellation also frees the reserved slot.
            self._release_acquired(key, placeholder)
            raise
        else:
            if self._closed:
                proto.close()
                raise ClientConnectionError("Connector is closed.")

    # The connection was successfully created, drop the placeholder
    # and add the real connection to the acquired set. There should
    # be no awaits after the proto is added to the acquired set
    # to ensure that the connection is not left in the acquired set
    # on cancellation.
    self._acquired.remove(placeholder)
    self._acquired.add(proto)
    if self._limit_per_host:
        acquired_per_host = self._acquired_per_host[key]
        acquired_per_host.remove(placeholder)
        acquired_per_host.add(proto)
    return Connection(self, key, proto, self._loop)
async def _get(
    self, key: "ConnectionKey", traces: list["Trace"]
) -> Connection | None:
    """Return a reusable pooled connection for *key*, or None.

    Connections are taken from the front of the pool. Ones that are
    disconnected or older than the keep-alive timeout are closed and
    skipped. The first usable one is marked as acquired and returned.
    """
    pool = self._conns.get(key)
    if pool is None:
        return None

    now = monotonic()
    while pool:
        proto, last_used = pool.popleft()
        # A connection is reused only if it is still connected and the
        # keepalive timeout has not been exceeded.
        reusable = (
            proto.is_connected() and now - last_used <= self._keepalive_timeout
        )
        if not reusable:
            # Connection cannot be reused, close it
            transport = proto.transport
            proto.close()
            # only for SSL transports
            if not self._cleanup_closed_disabled and key.is_ssl:
                self._cleanup_closed_transports.append(transport)
            continue

        if not pool:
            # The very last connection was reclaimed: drop the key
            del self._conns[key]
        self._acquired.add(proto)
        if self._limit_per_host:
            self._acquired_per_host[key].add(proto)
        if traces:
            try:
                for trace in traces:
                    await trace.send_connection_reuseconn()
            except BaseException:
                # Undo the acquisition if a trace callback fails or the
                # task is cancelled while it runs.
                self._release_acquired(key, proto)
                raise
        return Connection(self, key, proto, self._loop)

    # No more connections: drop the key
    del self._conns[key]
    return None
def _release(
    self,
    key: "ConnectionKey",
    protocol: ResponseHandler,
    *,
    should_close: bool = False,
) -> None:
    """Take back a protocol that a Connection no longer needs.

    The protocol is removed from the acquired bookkeeping and then either
    closed or put back into the pool with the current timestamp. It is
    closed when the connector forces closing, the caller asked for it, or
    the protocol itself reports ``should_close``.
    """
    if self._closed:
        # acquired connection is already released on connector closing
        return

    self._release_acquired(key, protocol)

    must_close = self._force_close or should_close or protocol.should_close
    if not must_close:
        self._conns[key].append((protocol, monotonic()))
        # Make sure a keep-alive cleanup is scheduled for the pooled entry.
        if self._cleanup_handle is None:
            self._cleanup_handle = helpers.weakref_handle(
                self,
                "_cleanup",
                self._keepalive_timeout,
                self._loop,
                timeout_ceil_threshold=self._timeout_ceil_threshold,
            )
        return

    # Grab the transport before closing so it can be queued for the
    # closed-transport cleanup.
    transport = protocol.transport
    protocol.close()
    if key.is_ssl and not self._cleanup_closed_disabled:
        self._cleanup_closed_transports.append(transport)
def _make_ssl_context(verified: bool) -> SSLContext:
    """Build the client SSL context, with or without verification.

    Not async-friendly: it reads certificates from disk and does other
    blocking I/O, so call it from a thread or before the event loop runs.

    With ``verified`` true the standard default context is used. Otherwise
    a TLS client context is created with hostname checking and certificate
    verification switched off. Both variants advertise ``http/1.1`` via
    ALPN.
    """
    if ssl is None:
        # No ssl support
        return None  # type: ignore[unreachable]

    if not verified:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        for legacy_flag in (ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv3):
            context.options |= legacy_flag
        # check_hostname has to be turned off before CERT_NONE is accepted.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        context.options |= ssl.OP_NO_COMPRESSION
        context.set_default_verify_paths()
    else:
        context = ssl.create_default_context()

    context.set_alpn_protocols(("http/1.1",))
    return context
_SSL_CONTEXT_VERIFIED = _make_ssl_context(True) _SSL_CONTEXT_UNVERIFIED = _make_ssl_context(False) class TCPConnector(BaseConnector): """TCP connector. verify_ssl - Set to True to check ssl certifications. fingerprint - Pass the binary sha256 digest of the expected certificate in DER format to verify that the certificate the server presents matches. See also https://en.wikipedia.org/wiki/HTTP_Public_Key_Pinning resolver - Enable DNS lookups and use this resolver use_dns_cache - Use memory cache for DNS lookups. ttl_dns_cache - Max seconds having cached a DNS entry, None forever. family - socket address family local_addr - local tuple of (host, port) to bind socket to keepalive_timeout - (optional) Keep-alive timeout. force_close - Set to True to force close and do reconnect after each request (and between redirects). limit - The total number of simultaneous connections. limit_per_host - Number of simultaneous connections to one host. enable_cleanup_closed - Enables clean-up closed ssl transports. Disabled by default. happy_eyeballs_delay - This is the “Connection Attempt Delay” as defined in RFC 8305. To disable the happy eyeballs algorithm, set to None. interleave - “First Address Family Count” as defined in RFC 8305 loop - Optional event loop. socket_factory - A SocketFactoryType function that, if supplied, will be used to create sockets given an AddrInfoType. ssl_shutdown_timeout - DEPRECATED. Will be removed in aiohttp 4.0. Grace period for SSL shutdown handshake on TLS connections. Default is 0 seconds (immediate abort). This parameter allowed for a clean SSL shutdown by notifying the remote peer of connection closure, while avoiding excessive delays during connector cleanup. Note: Only takes effect on Python 3.11+. 
""" allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) def __init__( self, *, use_dns_cache: bool = True, ttl_dns_cache: int | None = 10, dns_cache_max_size: int = 1000, family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC, ssl: bool | Fingerprint | SSLContext = True, local_addr: tuple[str, int] | None = None, resolver: AbstractResolver | None = None, keepalive_timeout: None | float | _SENTINEL = sentinel, force_close: bool = False, limit: int = 100, limit_per_host: int = 0, enable_cleanup_closed: bool = False, timeout_ceil_threshold: float = 5, happy_eyeballs_delay: float | None = 0.25, interleave: int | None = None, socket_factory: SocketFactoryType | None = None, ssl_shutdown_timeout: _SENTINEL | None | float = sentinel, ): super().__init__( keepalive_timeout=keepalive_timeout, force_close=force_close, limit=limit, limit_per_host=limit_per_host, enable_cleanup_closed=enable_cleanup_closed, timeout_ceil_threshold=timeout_ceil_threshold, ) if not isinstance(ssl, SSL_ALLOWED_TYPES): raise TypeError( "ssl should be SSLContext, Fingerprint, or bool, " f"got {ssl!r} instead." 
) self._ssl = ssl self._resolver: AbstractResolver if resolver is None: self._resolver = DefaultResolver() self._resolver_owner = True else: self._resolver = resolver self._resolver_owner = False self._use_dns_cache = use_dns_cache self._cached_hosts = _DNSCacheTable( ttl=ttl_dns_cache, max_size=dns_cache_max_size ) self._throttle_dns_futures: dict[tuple[str, int], set[asyncio.Future[None]]] = ( {} ) self._family = family self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr) self._happy_eyeballs_delay = happy_eyeballs_delay self._interleave = interleave self._resolve_host_tasks: set[asyncio.Task[list[ResolveResult]]] = set() self._socket_factory = socket_factory self._ssl_shutdown_timeout: float | None # Handle ssl_shutdown_timeout with warning for Python < 3.11 if ssl_shutdown_timeout is sentinel: self._ssl_shutdown_timeout = 0 else: # Deprecation warning for ssl_shutdown_timeout parameter warnings.warn( "The ssl_shutdown_timeout parameter is deprecated and will be removed in aiohttp 4.0", DeprecationWarning, stacklevel=2, ) if ( sys.version_info < (3, 11) and ssl_shutdown_timeout is not None and ssl_shutdown_timeout != 0 ): warnings.warn( f"ssl_shutdown_timeout={ssl_shutdown_timeout} is ignored on Python < 3.11; " "only ssl_shutdown_timeout=0 is supported. The timeout will be ignored.", RuntimeWarning, stacklevel=2, ) self._ssl_shutdown_timeout = ssl_shutdown_timeout async def close(self, *, abort_ssl: bool = False) -> None: """Close all opened transports. :param abort_ssl: If True, SSL connections will be aborted immediately without performing the shutdown handshake. 
If False (default), the behavior is determined by ssl_shutdown_timeout: - If ssl_shutdown_timeout=0: connections are aborted - If ssl_shutdown_timeout>0: graceful shutdown is performed """ if self._resolver_owner: await self._resolver.close() # Use abort_ssl param if explicitly set, otherwise use ssl_shutdown_timeout default await super().close(abort_ssl=abort_ssl or self._ssl_shutdown_timeout == 0) def _close_immediately(self, *, abort_ssl: bool = False) -> list[Awaitable[object]]: for fut in chain.from_iterable(self._throttle_dns_futures.values()): fut.cancel() waiters = super()._close_immediately(abort_ssl=abort_ssl) for t in self._resolve_host_tasks: t.cancel() waiters.append(t) return waiters @property def family(self) -> int: """Socket family like AF_INET.""" return self._family @property def use_dns_cache(self) -> bool: """True if local DNS caching is enabled.""" return self._use_dns_cache def clear_dns_cache(self, host: str | None = None, port: int | None = None) -> None: """Remove specified host/port or clear all dns local cache.""" if host is not None and port is not None: self._cached_hosts.remove((host, port)) elif host is not None or port is not None: raise ValueError("either both host and port or none of them are allowed") else: self._cached_hosts.clear() async def _resolve_host( self, host: str, port: int, traces: Sequence["Trace"] | None = None ) -> list[ResolveResult]: """Resolve host and return list of addresses.""" if is_ip_address(host): return [ { "hostname": host, "host": host, "port": port, "family": self._family, "proto": 0, "flags": 0, } ] if not self._use_dns_cache: if traces: for trace in traces: await trace.send_dns_resolvehost_start(host) res = await self._resolver.resolve(host, port, family=self._family) if traces: for trace in traces: await trace.send_dns_resolvehost_end(host) return res key = (host, port) if key in self._cached_hosts and not self._cached_hosts.expired(key): # get result early, before any await (#4014) result = 
self._cached_hosts.next_addrs(key) if traces: for trace in traces: await trace.send_dns_cache_hit(host) return result futures: set[asyncio.Future[None]] # # If multiple connectors are resolving the same host, we wait # for the first one to resolve and then use the result for all of them. # We use a throttle to ensure that we only resolve the host once # and then use the result for all the waiters. # if key in self._throttle_dns_futures: # get futures early, before any await (#4014) futures = self._throttle_dns_futures[key] future: asyncio.Future[None] = self._loop.create_future() futures.add(future) if traces: for trace in traces: await trace.send_dns_cache_hit(host) try: await future finally: futures.discard(future) return self._cached_hosts.next_addrs(key) # update dict early, before any await (#4014) self._throttle_dns_futures[key] = futures = set() # In this case we need to create a task to ensure that we can shield # the task from cancellation as cancelling this lookup should not cancel # the underlying lookup or else the cancel event will get broadcast to # all the waiters across all connections. 
# coro = self._resolve_host_with_throttle(key, host, port, futures, traces) loop = asyncio.get_running_loop() if sys.version_info >= (3, 12): # Optimization for Python 3.12, try to send immediately resolved_host_task = asyncio.Task(coro, loop=loop, eager_start=True) else: resolved_host_task = loop.create_task(coro) if not resolved_host_task.done(): self._resolve_host_tasks.add(resolved_host_task) resolved_host_task.add_done_callback(self._resolve_host_tasks.discard) try: return await asyncio.shield(resolved_host_task) except asyncio.CancelledError: def drop_exception(fut: "asyncio.Future[list[ResolveResult]]") -> None: with suppress(Exception, asyncio.CancelledError): fut.result() resolved_host_task.add_done_callback(drop_exception) raise async def _resolve_host_with_throttle( self, key: tuple[str, int], host: str, port: int, futures: set[asyncio.Future[None]], traces: Sequence["Trace"] | None, ) -> list[ResolveResult]: """Resolve host and set result for all waiters. This method must be run in a task and shielded from cancellation to avoid cancelling the underlying lookup. """ try: if traces: for trace in traces: await trace.send_dns_cache_miss(host) for trace in traces: await trace.send_dns_resolvehost_start(host) addrs = await self._resolver.resolve(host, port, family=self._family) if traces: for trace in traces: await trace.send_dns_resolvehost_end(host) self._cached_hosts.add(key, addrs) for fut in futures: set_result(fut, None) except BaseException as e: # any DNS exception is set for the waiters to raise the same exception. # This coro is always run in task that is shielded from cancellation so # we should never be propagating cancellation here. for fut in futures: set_exception(fut, e) raise finally: self._throttle_dns_futures.pop(key) return self._cached_hosts.next_addrs(key) async def _create_connection( self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: """Create connection. 
Has same keyword arguments as BaseEventLoop.create_connection. """ if req.proxy: _, proto = await self._create_proxy_connection(req, traces, timeout) else: _, proto = await self._create_direct_connection(req, traces, timeout) return proto def _get_ssl_context(self, req: ClientRequestBase) -> SSLContext | None: """Logic to get the correct SSL context 0. if req.ssl is false, return None 1. if ssl_context is specified in req, use it 2. if _ssl_context is specified in self, use it 3. otherwise: 1. if verify_ssl is not specified in req, use self.ssl_context (will generate a default context according to self.verify_ssl) 2. if verify_ssl is True in req, generate a default SSL context 3. if verify_ssl is False in req, generate a SSL context that won't verify """ if not req.is_ssl(): return None if ssl is None: # pragma: no cover raise RuntimeError("SSL is not supported.") sslcontext = req.ssl if isinstance(sslcontext, ssl.SSLContext): return sslcontext if sslcontext is not True: # not verified or fingerprinted return _SSL_CONTEXT_UNVERIFIED sslcontext = self._ssl if isinstance(sslcontext, ssl.SSLContext): return sslcontext if sslcontext is not True: # not verified or fingerprinted return _SSL_CONTEXT_UNVERIFIED return _SSL_CONTEXT_VERIFIED def _get_fingerprint(self, req: ClientRequestBase) -> "Fingerprint | None": ret = req.ssl if isinstance(ret, Fingerprint): return ret ret = self._ssl if isinstance(ret, Fingerprint): return ret return None async def _wrap_create_connection( self, *args: Any, addr_infos: list[AddrInfoType], req: ClientRequestBase, timeout: "ClientTimeout", client_error: type[Exception] = ClientConnectorError, **kwargs: Any, ) -> tuple[asyncio.Transport, ResponseHandler]: try: async with ceil_timeout( timeout.sock_connect, ceil_threshold=timeout.ceil_threshold ): sock = await aiohappyeyeballs.start_connection( addr_infos=addr_infos, local_addr_infos=self._local_addr_infos, happy_eyeballs_delay=self._happy_eyeballs_delay, interleave=self._interleave, 
loop=self._loop, socket_factory=self._socket_factory, ) # Add ssl_shutdown_timeout for Python 3.11+ when SSL is used if ( kwargs.get("ssl") and self._ssl_shutdown_timeout and sys.version_info >= (3, 11) ): kwargs["ssl_shutdown_timeout"] = self._ssl_shutdown_timeout return await self._loop.create_connection(*args, **kwargs, sock=sock) except cert_errors as exc: raise ClientConnectorCertificateError(req.connection_key, exc) from exc except ssl_errors as exc: raise ClientConnectorSSLError(req.connection_key, exc) from exc except OSError as exc: if exc.errno is None and isinstance(exc, asyncio.TimeoutError): raise raise client_error(req.connection_key, exc) from exc def _warn_about_tls_in_tls( self, underlying_transport: asyncio.Transport, req: ClientRequest, ) -> None: """Issue a warning if the requested URL has HTTPS scheme.""" if req.url.scheme != "https": return # TLS-in-TLS only applies when the proxy itself is HTTPS. # When the proxy is HTTP, start_tls upgrades a plain TCP connection, # which is standard TLS and works on all event loops and Python versions. if req.proxy is None or req.proxy.scheme != "https": return # Check if uvloop is being used, which supports TLS in TLS, # otherwise assume that asyncio's native transport is being used. if type(underlying_transport).__module__.startswith("uvloop"): return # Support in asyncio was added in Python 3.11 (bpo-44011) asyncio_supports_tls_in_tls = sys.version_info >= (3, 11) or getattr( underlying_transport, "_start_tls_compatible", False, ) if asyncio_supports_tls_in_tls: return warnings.warn( "An HTTPS request is being sent through an HTTPS proxy. " "This support for TLS in TLS is known to be disabled " "in the stdlib asyncio. This is why you'll probably see " "an error in the log below.\n\n" "It is possible to enable it via monkeypatching. 
" "For more details, see:\n" "* https://bugs.python.org/issue37179\n" "* https://github.com/python/cpython/pull/28073\n\n" "You can temporarily patch this as follows:\n" "* https://docs.aiohttp.org/en/stable/client_advanced.html#proxy-support\n" "* https://github.com/aio-libs/aiohttp/discussions/6044\n", RuntimeWarning, source=self, # Why `4`? At least 3 of the calls in the stack originate # from the methods in this class. stacklevel=3, ) async def _start_tls_connection( self, underlying_transport: asyncio.Transport, req: ClientRequest, timeout: "ClientTimeout", client_error: type[Exception] = ClientConnectorError, ) -> tuple[asyncio.BaseTransport, ResponseHandler]: """Wrap the raw TCP transport with TLS.""" tls_proto = self._factory() # Create a brand new proto for TLS sslcontext = self._get_ssl_context(req) if TYPE_CHECKING: # _start_tls_connection is unreachable in the current code path # if sslcontext is None. assert sslcontext is not None try: async with ceil_timeout( timeout.sock_connect, ceil_threshold=timeout.ceil_threshold ): try: # ssl_shutdown_timeout is only available in Python 3.11+ if sys.version_info >= (3, 11) and self._ssl_shutdown_timeout: tls_transport = await self._loop.start_tls( underlying_transport, tls_proto, sslcontext, server_hostname=req.server_hostname or req.url.raw_host, ssl_handshake_timeout=timeout.total, ssl_shutdown_timeout=self._ssl_shutdown_timeout, ) else: tls_transport = await self._loop.start_tls( underlying_transport, tls_proto, sslcontext, server_hostname=req.server_hostname or req.url.raw_host, ssl_handshake_timeout=timeout.total, ) except BaseException: # We need to close the underlying transport since # `start_tls()` probably failed before it had a # chance to do this: if self._ssl_shutdown_timeout == 0: underlying_transport.abort() else: underlying_transport.close() raise if isinstance(tls_transport, asyncio.Transport): fingerprint = self._get_fingerprint(req) if fingerprint: try: fingerprint.check(tls_transport) except 
ServerFingerprintMismatch: tls_transport.close() if not self._cleanup_closed_disabled: self._cleanup_closed_transports.append(tls_transport) raise except cert_errors as exc: raise ClientConnectorCertificateError(req.connection_key, exc) from exc except ssl_errors as exc: raise ClientConnectorSSLError(req.connection_key, exc) from exc except OSError as exc: if exc.errno is None and isinstance(exc, asyncio.TimeoutError): raise raise client_error(req.connection_key, exc) from exc except TypeError as type_err: # Example cause looks like this: # TypeError: transport is not supported by start_tls() raise ClientConnectionError( "Cannot initialize a TLS-in-TLS connection to host " f"{req.url.host!s}:{req.url.port:d} through an underlying connection " f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} " f"[{type_err!s}]" ) from type_err else: if tls_transport is None: msg = "Failed to start TLS (possibly caused by closing transport)" raise client_error(req.connection_key, OSError(msg)) tls_proto.connection_made( tls_transport ) # Kick the state machine of the new TLS protocol return tls_transport, tls_proto def _convert_hosts_to_addr_infos( self, hosts: list[ResolveResult] ) -> list[AddrInfoType]: """Converts the list of hosts to a list of addr_infos. The list of hosts is the result of a DNS lookup. The list of addr_infos is the result of a call to `socket.getaddrinfo()`. 
""" addr_infos: list[AddrInfoType] = [] for hinfo in hosts: host = hinfo["host"] is_ipv6 = ":" in host family = socket.AF_INET6 if is_ipv6 else socket.AF_INET if self._family and self._family != family: continue addr = (host, hinfo["port"], 0, 0) if is_ipv6 else (host, hinfo["port"]) addr_infos.append( (family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr) ) return addr_infos async def _create_direct_connection( self, req: ClientRequestBase, traces: list["Trace"], timeout: "ClientTimeout", *, client_error: type[Exception] = ClientConnectorError, ) -> tuple[asyncio.Transport, ResponseHandler]: sslcontext = self._get_ssl_context(req) fingerprint = self._get_fingerprint(req) host = req.url.raw_host assert host is not None # Replace multiple trailing dots with a single one. # A trailing dot is only present for fully-qualified domain names. # See https://github.com/aio-libs/aiohttp/pull/7364. if host.endswith(".."): host = host.rstrip(".") + "." port = req.url.port assert port is not None try: # Cancelling this lookup should not cancel the underlying lookup # or else the cancel event will get broadcast to all the waiters # across all connections. hosts = await self._resolve_host(host, port, traces=traces) except OSError as exc: if exc.errno is None and isinstance(exc, asyncio.TimeoutError): raise # in case of proxy it is not ClientProxyConnectionError # it is problem of resolving proxy ip itself raise ClientConnectorDNSError(req.connection_key, exc) from exc last_exc: Exception | None = None addr_infos = self._convert_hosts_to_addr_infos(hosts) while addr_infos: # Strip trailing dots, certificates contain FQDN without dots. 
    async def _create_proxy_connection(
        self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout"
    ) -> tuple[asyncio.BaseTransport, ResponseHandler]:
        """Connect to the proxy and, for SSL requests, tunnel through it.

        A direct connection to the proxy is opened first. For a non-SSL
        request that transport/protocol pair is returned as is. For an SSL
        request a ``CONNECT`` request is sent over the proxy connection and,
        once the proxy answers with status 200, the transport is upgraded
        with ``_start_tls_connection``.

        :param req: the client request; its ``url`` is the CONNECT target.
        :param traces: not used for the proxy connection itself (an empty
            list is passed to ``_create_direct_connection``).
        :param timeout: client timeout forwarded to the connection helpers.
        :raises ClientHttpProxyError: if the proxy answers the ``CONNECT``
            request with a status other than 200.
        """
        proxy_req = self._update_proxy_auth_header_and_build_proxy_req(req)

        # create connection to proxy server
        transport, proto = await self._create_direct_connection(
            proxy_req, [], timeout, client_error=ClientProxyConnectionError
        )

        if req.is_ssl():
            self._warn_about_tls_in_tls(transport, req)

            # For HTTPS requests over HTTP proxy
            # we must notify proxy to tunnel connection
            # so we send CONNECT command:
            #   CONNECT www.python.org:443 HTTP/1.1
            #   Host: www.python.org
            #
            # next we must do TLS handshake and so on
            # to do this we must wrap raw socket into secure one
            # asyncio handles this perfectly
            proxy_req.method = hdrs.METH_CONNECT
            proxy_req.url = req.url
            # The tunnel connection is keyed without any proxy details.
            key = req.connection_key._replace(
                proxy=None, proxy_auth=None, proxy_headers_hash=None
            )
            conn = _ConnectTunnelConnection(self, key, proto, self._loop)
            proxy_resp = await proxy_req._send(conn)
            try:
                protocol = conn._protocol
                assert protocol is not None

                # read_until_eof=True will ensure the connection isn't closed
                # once the response is received and processed allowing
                # START_TLS to work on the connection below.
                protocol.set_response_params(
                    read_until_eof=True,
                    timeout_ceil_threshold=self._timeout_ceil_threshold,
                )
                resp = await proxy_resp.start(conn)
            except BaseException:
                # Reading the CONNECT response failed: release both the
                # response and the tunnel connection before re-raising.
                proxy_resp.close()
                conn.close()
                raise
            else:
                # Detach the protocol from the tunnel connection object
                # before the transport is handed over for the TLS upgrade.
                conn._protocol = None
                try:
                    if resp.status != 200:
                        message = resp.reason
                        if message is None:
                            message = HTTPStatus(resp.status).phrase
                        raise ClientHttpProxyError(
                            proxy_resp.request_info,
                            resp.history,
                            status=resp.status,
                            message=message,
                            headers=resp.headers,
                        )
                except BaseException:
                    # It shouldn't be closed in `finally` because it's fed to
                    # `loop.start_tls()` and the docs say not to touch it after
                    # passing there.
                    transport.close()
                    raise

                return await self._start_tls_connection(
                    # Access the old transport for the last time before it's
                    # closed and forgotten forever:
                    transport,
                    req=req,
                    timeout=timeout,
                )
            finally:
                # The CONNECT response object is closed on every path.
                proxy_resp.close()

        return transport, proto
""" allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"}) def __init__( self, path: str, force_close: bool = False, keepalive_timeout: _SENTINEL | float | None = sentinel, limit: int = 100, limit_per_host: int = 0, ) -> None: super().__init__( force_close=force_close, keepalive_timeout=keepalive_timeout, limit=limit, limit_per_host=limit_per_host, ) self._path = path @property def path(self) -> str: """Path to unix socket.""" return self._path async def _create_connection( self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout( timeout.sock_connect, ceil_threshold=timeout.ceil_threshold ): _, proto = await self._loop.create_unix_connection( self._factory, self._path ) except OSError as exc: if exc.errno is None and isinstance(exc, asyncio.TimeoutError): raise raise UnixClientConnectorError(self.path, req.connection_key, exc) from exc return proto class NamedPipeConnector(BaseConnector): """Named pipe connector. Only supported by the proactor event loop. See also: https://docs.python.org/3/library/asyncio-eventloop.html path - Windows named pipe path. keepalive_timeout - (optional) Keep-alive timeout. force_close - Set to True to force close and do reconnect after each request (and between redirects). limit - The total number of simultaneous connections. limit_per_host - Number of simultaneous connections to one host. loop - Optional event loop. 
""" allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"}) def __init__( self, path: str, force_close: bool = False, keepalive_timeout: _SENTINEL | float | None = sentinel, limit: int = 100, limit_per_host: int = 0, ) -> None: super().__init__( force_close=force_close, keepalive_timeout=keepalive_timeout, limit=limit, limit_per_host=limit_per_host, ) if not isinstance( self._loop, asyncio.ProactorEventLoop, # type: ignore[attr-defined] ): raise RuntimeError( "Named Pipes only available in proactor loop under windows" ) self._path = path @property def path(self) -> str: """Path to the named pipe.""" return self._path async def _create_connection( self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout( timeout.sock_connect, ceil_threshold=timeout.ceil_threshold ): _, proto = await self._loop.create_pipe_connection( # type: ignore[attr-defined] self._factory, self._path ) # the drain is required so that the connection_made is called # and transport is set otherwise it is not set before the # `assert conn.transport is not None` # in client.py's _request method await asyncio.sleep(0) # other option is to manually set transport like # `proto.transport = trans` except OSError as exc: if exc.errno is None and isinstance(exc, asyncio.TimeoutError): raise raise ClientConnectorError(req.connection_key, exc) from exc return cast(ResponseHandler, proto) ================================================ FILE: aiohttp/cookiejar.py ================================================ import calendar import contextlib import datetime import heapq import itertools import json import os # noqa import pathlib import pickle import re import time import warnings from collections import defaultdict from collections.abc import Iterable, Iterator, Mapping from http.cookies import BaseCookie, Morsel, SimpleCookie from typing import Union from yarl import URL from ._cookie_helpers import 
preserve_morsel_with_coded_value from .abc import AbstractCookieJar, ClearCookiePredicate from .helpers import is_ip_address from .typedefs import LooseCookies, PathLike, StrOrURL __all__ = ("CookieJar", "DummyCookieJar") CookieItem = Union[str, "Morsel[str]"] # We cache these string methods here as their use is in performance critical code. _FORMAT_PATH = "{}/{}".format _FORMAT_DOMAIN_REVERSED = "{1}.{0}".format # The minimum number of scheduled cookie expirations before we start cleaning up # the expiration heap. This is a performance optimization to avoid cleaning up the # heap too often when there are only a few scheduled expirations. _MIN_SCHEDULED_COOKIE_EXPIRATION = 100 _SIMPLE_COOKIE = SimpleCookie() class _RestrictedCookieUnpickler(pickle.Unpickler): """A restricted unpickler that only allows cookie-related types. This prevents arbitrary code execution when loading pickled cookie data from untrusted sources. Only types that are expected in a serialized CookieJar are permitted. See: https://docs.python.org/3/library/pickle.html#restricting-globals """ _ALLOWED_CLASSES: frozenset[tuple[str, str]] = frozenset( { # Core cookie types ("http.cookies", "SimpleCookie"), ("http.cookies", "Morsel"), # Container types used by CookieJar._cookies ("collections", "defaultdict"), # builtins that pickle uses for reconstruction ("builtins", "tuple"), ("builtins", "set"), ("builtins", "frozenset"), ("builtins", "dict"), } ) def find_class(self, module: str, name: str) -> type: if (module, name) not in self._ALLOWED_CLASSES: raise pickle.UnpicklingError( f"Forbidden class: {module}.{name}. " "CookieJar.load() only allows cookie-related types for security. 
" "See https://docs.python.org/3/library/pickle.html#restricting-globals" ) return super().find_class(module, name) # type: ignore[no-any-return] class CookieJar(AbstractCookieJar): """Implements cookie storage adhering to RFC 6265.""" DATE_TOKENS_RE = re.compile( r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*" r"(?P[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)" ) DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})") DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})") DATE_MONTH_RE = re.compile( "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)", re.I, ) DATE_YEAR_RE = re.compile(r"(\d{2,4})") # calendar.timegm() fails for timestamps after datetime.datetime.max # Minus one as a loss of precision occurs when timestamp() is called. MAX_TIME = ( int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1 ) try: calendar.timegm(time.gmtime(MAX_TIME)) except OSError: # Hit the maximum representable time on Windows # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64 MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1)) except OverflowError: # #4515: datetime.max may not be representable on 32-bit platforms MAX_TIME = 2**31 - 1 # Avoid minuses in the future, 3x faster SUB_MAX_TIME = MAX_TIME - 1 def __init__( self, *, unsafe: bool = False, quote_cookie: bool = True, treat_as_secure_origin: StrOrURL | Iterable[StrOrURL] | None = None, ) -> None: self._cookies: defaultdict[tuple[str, str], SimpleCookie] = defaultdict( SimpleCookie ) self._morsel_cache: defaultdict[tuple[str, str], dict[str, Morsel[str]]] = ( defaultdict(dict) ) self._host_only_cookies: set[tuple[str, str]] = set() self._unsafe = unsafe self._quote_cookie = quote_cookie if treat_as_secure_origin is None: self._treat_as_secure_origin: frozenset[URL] = frozenset() elif isinstance(treat_as_secure_origin, URL): self._treat_as_secure_origin = frozenset({treat_as_secure_origin.origin()}) elif 
isinstance(treat_as_secure_origin, str): self._treat_as_secure_origin = frozenset( {URL(treat_as_secure_origin).origin()} ) else: self._treat_as_secure_origin = frozenset( { URL(url).origin() if isinstance(url, str) else url.origin() for url in treat_as_secure_origin } ) self._expire_heap: list[tuple[float, tuple[str, str, str]]] = [] self._expirations: dict[tuple[str, str, str], float] = {} @property def quote_cookie(self) -> bool: return self._quote_cookie def save(self, file_path: PathLike) -> None: """Save cookies to a file using JSON format. :param file_path: Path to file where cookies will be serialized, :class:`str` or :class:`pathlib.Path` instance. """ file_path = pathlib.Path(file_path) data: dict[str, dict[str, dict[str, str | bool]]] = {} for (domain, path), cookie in self._cookies.items(): key = f"{domain}|{path}" data[key] = {} for name, morsel in cookie.items(): morsel_data: dict[str, str | bool] = { "key": morsel.key, "value": morsel.value, "coded_value": morsel.coded_value, } # Save all morsel attributes that have values for attr in morsel._reserved: # type: ignore[attr-defined] attr_val = morsel[attr] if attr_val: morsel_data[attr] = attr_val data[key][name] = morsel_data with file_path.open(mode="w", encoding="utf-8") as f: json.dump(data, f, indent=2) def load(self, file_path: PathLike) -> None: """Load cookies from a file. Tries to load JSON format first. Falls back to loading legacy pickle format (using a restricted unpickler) for backward compatibility with existing cookie files. :param file_path: Path to file from where cookies will be imported, :class:`str` or :class:`pathlib.Path` instance. 
""" file_path = pathlib.Path(file_path) # Try JSON format first try: with file_path.open(mode="r", encoding="utf-8") as f: data = json.load(f) self._cookies = self._load_json_data(data) except (json.JSONDecodeError, UnicodeDecodeError, ValueError): # Fall back to legacy pickle format with restricted unpickler with file_path.open(mode="rb") as f: self._cookies = _RestrictedCookieUnpickler(f).load() def _load_json_data( self, data: dict[str, dict[str, dict[str, str | bool]]] ) -> defaultdict[tuple[str, str], SimpleCookie]: """Load cookies from parsed JSON data.""" cookies: defaultdict[tuple[str, str], SimpleCookie] = defaultdict(SimpleCookie) for compound_key, cookie_data in data.items(): domain, path = compound_key.split("|", 1) key = (domain, path) for name, morsel_data in cookie_data.items(): morsel: Morsel[str] = Morsel() morsel_key = morsel_data["key"] morsel_value = morsel_data["value"] morsel_coded_value = morsel_data["coded_value"] # Use __setstate__ to bypass validation, same pattern # used in _build_morsel and _cookie_helpers. 
morsel.__setstate__( # type: ignore[attr-defined] { "key": morsel_key, "value": morsel_value, "coded_value": morsel_coded_value, } ) # Restore morsel attributes for attr in morsel._reserved: # type: ignore[attr-defined] if attr in morsel_data and attr not in ( "key", "value", "coded_value", ): morsel[attr] = morsel_data[attr] cookies[key][name] = morsel return cookies def clear(self, predicate: ClearCookiePredicate | None = None) -> None: if predicate is None: self._expire_heap.clear() self._cookies.clear() self._morsel_cache.clear() self._host_only_cookies.clear() self._expirations.clear() return now = time.time() to_del = [ key for (domain, path), cookie in self._cookies.items() for name, morsel in cookie.items() if ( (key := (domain, path, name)) in self._expirations and self._expirations[key] <= now ) or predicate(morsel) ] if to_del: self._delete_cookies(to_del) def clear_domain(self, domain: str) -> None: self.clear(lambda x: self._is_domain_match(domain, x["domain"])) def __iter__(self) -> "Iterator[Morsel[str]]": self._do_expiration() for val in self._cookies.values(): yield from val.values() def __len__(self) -> int: """Return number of cookies. This function does not iterate self to avoid unnecessary expiration checks. """ return sum(len(cookie.values()) for cookie in self._cookies.values()) def _do_expiration(self) -> None: """Remove expired cookies.""" if not (expire_heap_len := len(self._expire_heap)): return # If the expiration heap grows larger than the number expirations # times two, we clean it up to avoid keeping expired entries in # the heap and consuming memory. We guard this with a minimum # threshold to avoid cleaning up the heap too often when there are # only a few scheduled expirations. 
if ( expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION and expire_heap_len > len(self._expirations) * 2 ): # Remove any expired entries from the expiration heap # that do not match the expiration time in the expirations # as it means the cookie has been re-added to the heap # with a different expiration time. self._expire_heap = [ entry for entry in self._expire_heap if self._expirations.get(entry[1]) == entry[0] ] heapq.heapify(self._expire_heap) now = time.time() to_del: list[tuple[str, str, str]] = [] # Find any expired cookies and add them to the to-delete list while self._expire_heap: when, cookie_key = self._expire_heap[0] if when > now: break heapq.heappop(self._expire_heap) # Check if the cookie hasn't been re-added to the heap # with a different expiration time as it will be removed # later when it reaches the top of the heap and its # expiration time is met. if self._expirations.get(cookie_key) == when: to_del.append(cookie_key) if to_del: self._delete_cookies(to_del) def _delete_cookies(self, to_del: list[tuple[str, str, str]]) -> None: for domain, path, name in to_del: self._host_only_cookies.discard((domain, name)) self._cookies[(domain, path)].pop(name, None) self._morsel_cache[(domain, path)].pop(name, None) self._expirations.pop((domain, path, name), None) def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None: cookie_key = (domain, path, name) if self._expirations.get(cookie_key) == when: # Avoid adding duplicates to the heap return heapq.heappush(self._expire_heap, (when, cookie_key)) self._expirations[cookie_key] = when def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: """Update cookies.""" hostname = response_url.raw_host if not self._unsafe and is_ip_address(hostname): # Don't accept cookies from IPs return if isinstance(cookies, Mapping): cookies = cookies.items() for name, cookie in cookies: if not isinstance(cookie, Morsel): tmp = SimpleCookie() tmp[name] = cookie # type: 
def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
    """Return the jar's cookies that apply to ``request_url``.

    Expired cookies are purged first. The shared cookies stored under
    the ``("", "")`` key are always included. For every other cookie the
    host-only flag, the path length and the ``secure`` attribute are
    checked against the request before the cookie is added.

    Morsels built for sending are cached in ``self._morsel_cache`` and
    reused on later calls.

    Passing anything other than a ``yarl.URL`` is deprecated; the value
    is converted with ``URL(request_url)`` after a ``DeprecationWarning``.
    """
    if not isinstance(request_url, URL):
        # stacklevel=2 attributes the warning to the caller, so the
        # default warning filters can actually show it.
        warnings.warn(  # type: ignore[unreachable]
            f"The method accepts yarl.URL instances only, got {type(request_url)}",
            DeprecationWarning,
            stacklevel=2,
        )
        request_url = URL(request_url)

    # We always use BaseCookie now since all
    # cookies set on filtered are fully constructed
    # Morsels, not just names and values.
    filtered: BaseCookie[str] = BaseCookie()
    if not self._cookies:
        # Skip do_expiration() if there are no cookies.
        return filtered
    self._do_expiration()
    if not self._cookies:
        # Skip rest of function if no non-expired cookies.
        return filtered

    hostname = request_url.raw_host or ""

    is_not_secure = request_url.scheme not in ("https", "wss")
    if is_not_secure and self._treat_as_secure_origin:
        request_origin = URL()
        with contextlib.suppress(ValueError):
            request_origin = request_url.origin()
        is_not_secure = request_origin not in self._treat_as_secure_origin

    # Send shared cookie
    key = ("", "")
    for c in self._cookies[key].values():
        # Check cache first
        if c.key in self._morsel_cache[key]:
            filtered[c.key] = self._morsel_cache[key][c.key]
            continue

        # Build and cache the morsel
        mrsl_val = self._build_morsel(c)
        self._morsel_cache[key][c.key] = mrsl_val
        filtered[c.key] = mrsl_val

    if is_ip_address(hostname):
        if not self._unsafe:
            return filtered
        domains: Iterable[str] = (hostname,)
    else:
        # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
        domains = itertools.accumulate(
            reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
        )

    # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
    paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
    # Create every combination of (domain, path) pairs.
    pairs = itertools.product(domains, paths)

    path_len = len(request_url.path)
    # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
    for p in pairs:
        if p not in self._cookies:
            continue
        for name, cookie in self._cookies[p].items():
            domain = cookie["domain"]

            if (domain, name) in self._host_only_cookies and domain != hostname:
                continue

            # Skip edge case when the cookie has a trailing slash but request doesn't.
            if len(cookie["path"]) > path_len:
                continue

            if is_not_secure and cookie["secure"]:
                continue

            # We already built the Morsel so reuse it here
            if name in self._morsel_cache[p]:
                filtered[name] = self._morsel_cache[p][name]
                continue

            # Build and cache the morsel
            mrsl_val = self._build_morsel(cookie)
            self._morsel_cache[p][name] = mrsl_val
            filtered[name] = mrsl_val

    return filtered
morsel.__setstate__({"key": cookie.key, "value": value, "coded_value": coded_value}) # type: ignore[attr-defined] return morsel @staticmethod def _is_domain_match(domain: str, hostname: str) -> bool: """Implements domain matching adhering to RFC 6265.""" if hostname == domain: return True if not hostname.endswith(domain): return False non_matching = hostname[: -len(domain)] if not non_matching.endswith("."): return False return not is_ip_address(hostname) @classmethod def _parse_date(cls, date_str: str) -> int | None: """Implements date string parsing adhering to RFC 6265.""" if not date_str: return None found_time = False found_day = False found_month = False found_year = False hour = minute = second = 0 day = 0 month = 0 year = 0 for token_match in cls.DATE_TOKENS_RE.finditer(date_str): token = token_match.group("token") if not found_time: time_match = cls.DATE_HMS_TIME_RE.match(token) if time_match: found_time = True hour, minute, second = (int(s) for s in time_match.groups()) continue if not found_day: day_match = cls.DATE_DAY_OF_MONTH_RE.match(token) if day_match: found_day = True day = int(day_match.group()) continue if not found_month: month_match = cls.DATE_MONTH_RE.match(token) if month_match: found_month = True assert month_match.lastindex is not None month = month_match.lastindex continue if not found_year: year_match = cls.DATE_YEAR_RE.match(token) if year_match: found_year = True year = int(year_match.group()) if 70 <= year <= 99: year += 1900 elif 0 <= year <= 69: year += 2000 if False in (found_day, found_month, found_year, found_time): return None if not 1 <= day <= 31: return None if year < 1601 or hour > 23 or minute > 59 or second > 59: return None return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1)) class DummyCookieJar(AbstractCookieJar): """Implements a dummy cookie storage. It can be used with the ClientSession when no cookie processing is needed. 
class FormData:
    """Helper class for form body generation.

    Supports multipart/form-data and application/x-www-form-urlencoded.
    Fields are collected with :meth:`add_field` / :meth:`add_fields`;
    calling the instance produces the payload. The form switches to
    multipart as soon as a field carries binary data, a file object,
    a filename or an explicit content type.
    """

    def __init__(
        self,
        fields: Iterable[Any] = (),
        quote_fields: bool = True,
        charset: str | None = None,
        boundary: str | None = None,
        *,
        default_to_multipart: bool = False,
    ) -> None:
        self._boundary = boundary
        self._writer = multipart.MultipartWriter("form-data", boundary=self._boundary)
        self._fields: list[Any] = []
        self._is_multipart = default_to_multipart
        self._quote_fields = quote_fields
        self._charset = charset

        # Normalize the accepted input shapes to a sequence of records.
        if isinstance(fields, dict):
            fields = list(fields.items())
        elif not isinstance(fields, (list, tuple)):
            fields = (fields,)
        self.add_fields(*fields)

    @property
    def is_multipart(self) -> bool:
        """Whether the form will be encoded as multipart/form-data."""
        return self._is_multipart

    def add_field(
        self,
        name: str,
        value: Any,
        *,
        content_type: str | None = None,
        filename: str | None = None,
    ) -> None:
        """Add a single field to the form.

        Raises TypeError if ``filename`` or ``content_type`` is given but
        is not a str, and ValueError if ``content_type`` contains a
        carriage return or newline.
        """
        if isinstance(value, (io.IOBase, bytes, bytearray, memoryview)):
            self._is_multipart = True

        type_options: MultiDict[str] = MultiDict({"name": name})
        if filename is not None and not isinstance(filename, str):
            # f-string instead of "%s" % filename: the %-operator treats a
            # tuple operand as an argument list and would produce a wrong
            # or unrelated error message.
            raise TypeError(f"filename must be an instance of str. Got: {filename}")
        if filename is None and isinstance(value, io.IOBase):
            filename = guess_filename(value, name)
        if filename is not None:
            type_options["filename"] = filename
            self._is_multipart = True

        headers = {}
        if content_type is not None:
            if not isinstance(content_type, str):
                raise TypeError(
                    f"content_type must be an instance of str. Got: {content_type}"
                )
            if "\r" in content_type or "\n" in content_type:
                raise ValueError(
                    "Newline or carriage return detected in headers. "
                    "Potential header injection attack."
                )
            headers[hdrs.CONTENT_TYPE] = content_type
            self._is_multipart = True

        self._fields.append((type_options, headers, value))

    def add_fields(self, *fields: Any) -> None:
        """Add several fields at once.

        Each record may be a file object (its name is guessed, falling
        back to ``"unknown"``), a multidict (expanded into its items) or
        a two-item ``(name, value)`` list/tuple. Anything else raises
        TypeError.
        """
        to_add: deque[Any] = deque(fields)

        while to_add:
            rec = to_add.popleft()

            if isinstance(rec, io.IOBase):
                k = guess_filename(rec, "unknown")
                self.add_field(k, rec)  # type: ignore[arg-type]

            elif isinstance(rec, (MultiDictProxy, MultiDict)):
                to_add.extend(rec.items())

            elif isinstance(rec, (list, tuple)) and len(rec) == 2:
                k, fp = rec
                self.add_field(k, fp)

            else:
                raise TypeError(
                    "Only io.IOBase, multidict and (name, file) "
                    "pairs allowed, use .add_field() for passing "
                    f"more complex parameters, got {rec!r}"
                )

    def _gen_form_urlencoded(self) -> payload.BytesPayload:
        """Encode the fields as application/x-www-form-urlencoded.

        Raises TypeError if any field value is not a str. The charset is
        only advertised in the content type when it is not utf-8.
        """
        # form data (x-www-form-urlencoded)
        data = []
        for type_options, _, value in self._fields:
            if not isinstance(value, str):
                raise TypeError(f"expected str, got {value!r}")
            data.append((type_options["name"], value))

        charset = self._charset if self._charset is not None else "utf-8"

        if charset == "utf-8":
            content_type = "application/x-www-form-urlencoded"
        else:
            content_type = "application/x-www-form-urlencoded; charset=%s" % charset

        return payload.BytesPayload(
            urlencode(data, doseq=True, encoding=charset).encode(),
            content_type=content_type,
        )

    def _gen_form_data(self) -> multipart.MultipartWriter:
        """Encode a list of fields using the multipart/form-data MIME format"""
        for dispparams, headers, value in self._fields:
            try:
                if hdrs.CONTENT_TYPE in headers:
                    part = payload.get_payload(
                        value,
                        content_type=headers[hdrs.CONTENT_TYPE],
                        headers=headers,
                        encoding=self._charset,
                    )
                else:
                    part = payload.get_payload(
                        value, headers=headers, encoding=self._charset
                    )
            except Exception as exc:
                raise TypeError(
                    "Can not serialize value type: %r\n "
                    "headers: %r\n value: %r" % (type(value), headers, value)
                ) from exc

            if dispparams:
                part.set_content_disposition(
                    "form-data", quote_fields=self._quote_fields, **dispparams
                )
                # FIXME cgi.FieldStorage doesn't like body parts with
                # Content-Length which were sent via chunked transfer encoding
                assert part.headers is not None
                part.headers.popall(hdrs.CONTENT_LENGTH, None)

            self._writer.append_payload(part)

        # The fields now live in the writer; drop them so a second call
        # does not append them again.
        self._fields.clear()
        return self._writer

    def __call__(self) -> Payload:
        """Build the payload in whichever encoding the fields require."""
        if self._is_multipart:
            return self._gen_form_data()
        else:
            return self._gen_form_urlencoded()
istr("Access-Control-Request-Method") AGE: Final[istr] = istr("Age") ALLOW: Final[istr] = istr("Allow") AUTHORIZATION: Final[istr] = istr("Authorization") CACHE_CONTROL: Final[istr] = istr("Cache-Control") CONNECTION: Final[istr] = istr("Connection") CONTENT_DISPOSITION: Final[istr] = istr("Content-Disposition") CONTENT_ENCODING: Final[istr] = istr("Content-Encoding") CONTENT_LANGUAGE: Final[istr] = istr("Content-Language") CONTENT_LENGTH: Final[istr] = istr("Content-Length") CONTENT_LOCATION: Final[istr] = istr("Content-Location") CONTENT_MD5: Final[istr] = istr("Content-MD5") CONTENT_RANGE: Final[istr] = istr("Content-Range") CONTENT_TRANSFER_ENCODING: Final[istr] = istr("Content-Transfer-Encoding") CONTENT_TYPE: Final[istr] = istr("Content-Type") COOKIE: Final[istr] = istr("Cookie") DATE: Final[istr] = istr("Date") DESTINATION: Final[istr] = istr("Destination") DIGEST: Final[istr] = istr("Digest") ETAG: Final[istr] = istr("Etag") EXPECT: Final[istr] = istr("Expect") EXPIRES: Final[istr] = istr("Expires") FORWARDED: Final[istr] = istr("Forwarded") FROM: Final[istr] = istr("From") HOST: Final[istr] = istr("Host") IF_MATCH: Final[istr] = istr("If-Match") IF_MODIFIED_SINCE: Final[istr] = istr("If-Modified-Since") IF_NONE_MATCH: Final[istr] = istr("If-None-Match") IF_RANGE: Final[istr] = istr("If-Range") IF_UNMODIFIED_SINCE: Final[istr] = istr("If-Unmodified-Since") KEEP_ALIVE: Final[istr] = istr("Keep-Alive") LAST_EVENT_ID: Final[istr] = istr("Last-Event-ID") LAST_MODIFIED: Final[istr] = istr("Last-Modified") LINK: Final[istr] = istr("Link") LOCATION: Final[istr] = istr("Location") MAX_FORWARDS: Final[istr] = istr("Max-Forwards") ORIGIN: Final[istr] = istr("Origin") PRAGMA: Final[istr] = istr("Pragma") PROXY_AUTHENTICATE: Final[istr] = istr("Proxy-Authenticate") PROXY_AUTHORIZATION: Final[istr] = istr("Proxy-Authorization") RANGE: Final[istr] = istr("Range") REFERER: Final[istr] = istr("Referer") RETRY_AFTER: Final[istr] = istr("Retry-After") SEC_WEBSOCKET_ACCEPT: 
Final[istr] = istr("Sec-WebSocket-Accept") SEC_WEBSOCKET_VERSION: Final[istr] = istr("Sec-WebSocket-Version") SEC_WEBSOCKET_PROTOCOL: Final[istr] = istr("Sec-WebSocket-Protocol") SEC_WEBSOCKET_EXTENSIONS: Final[istr] = istr("Sec-WebSocket-Extensions") SEC_WEBSOCKET_KEY: Final[istr] = istr("Sec-WebSocket-Key") SEC_WEBSOCKET_KEY1: Final[istr] = istr("Sec-WebSocket-Key1") SERVER: Final[istr] = istr("Server") SET_COOKIE: Final[istr] = istr("Set-Cookie") TE: Final[istr] = istr("TE") TRAILER: Final[istr] = istr("Trailer") TRANSFER_ENCODING: Final[istr] = istr("Transfer-Encoding") UPGRADE: Final[istr] = istr("Upgrade") URI: Final[istr] = istr("URI") USER_AGENT: Final[istr] = istr("User-Agent") VARY: Final[istr] = istr("Vary") VIA: Final[istr] = istr("Via") WANT_DIGEST: Final[istr] = istr("Want-Digest") WARNING: Final[istr] = istr("Warning") WWW_AUTHENTICATE: Final[istr] = istr("WWW-Authenticate") X_FORWARDED_FOR: Final[istr] = istr("X-Forwarded-For") X_FORWARDED_HOST: Final[istr] = istr("X-Forwarded-Host") X_FORWARDED_PROTO: Final[istr] = istr("X-Forwarded-Proto") # These are the upper/lower case variants of the headers/methods # Example: {'hOst', 'host', 'HoST', 'HOSt', 'hOsT', 'HosT', 'hoSt', ...} METH_HEAD_ALL: Final = frozenset( map("".join, itertools.product(*zip(METH_HEAD.upper(), METH_HEAD.lower()))) ) METH_CONNECT_ALL: Final = frozenset( map("".join, itertools.product(*zip(METH_CONNECT.upper(), METH_CONNECT.lower()))) ) HOST_ALL: Final = frozenset( map("".join, itertools.product(*zip(HOST.upper(), HOST.lower()))) ) ================================================ FILE: aiohttp/helpers.py ================================================ """Various helper functions""" import asyncio import base64 import binascii import contextlib import dataclasses import datetime import enum import functools import inspect import netrc import os import platform import re import sys import time import warnings import weakref from collections import namedtuple from collections.abc 
import Callable, Iterable, Iterator, Mapping from contextlib import suppress from email.message import EmailMessage from email.parser import HeaderParser from email.policy import HTTP from email.utils import parsedate from http.cookies import SimpleCookie from math import ceil from pathlib import Path from types import MappingProxyType, TracebackType from typing import ( TYPE_CHECKING, Any, ContextManager, Generic, Optional, Protocol, TypeVar, Union, final, get_args, overload, ) from urllib.parse import quote from urllib.request import getproxies, proxy_bypass from multidict import CIMultiDict, MultiDict, MultiDictProxy, MultiMapping from propcache.api import under_cached_property as reify from yarl import URL from . import hdrs from .log import client_logger from .typedefs import PathLike # noqa if sys.version_info >= (3, 11): import asyncio as async_timeout else: import async_timeout if TYPE_CHECKING: from dataclasses import dataclass as frozen_dataclass_decorator else: frozen_dataclass_decorator = functools.partial( dataclasses.dataclass, frozen=True, slots=True ) __all__ = ("BasicAuth", "ChainMapProxy", "ETag", "frozen_dataclass_decorator", "reify") COOKIE_MAX_LENGTH = 4096 _T = TypeVar("_T") _S = TypeVar("_S") _SENTINEL = enum.Enum("_SENTINEL", "sentinel") sentinel = _SENTINEL.sentinel NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS")) # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1 EMPTY_BODY_STATUS_CODES = frozenset((204, 304, *range(100, 200))) # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1 # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2 EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL DEBUG = sys.flags.dev_mode or ( not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG")) ) CHAR = {chr(i) for i in range(0, 128)} CTL = {chr(i) for i in range(0, 32)} | { chr(127), } SEPARATORS = { "(", ")", "<", ">", "@", ",", ";", ":", "\\", '"', "/", "[", "]", "?", "=", "{", "}", " ", chr(9), } TOKEN 
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """HTTP Basic authentication credentials.

    A named tuple of ``(login, password, encoding)`` that can be built
    from an ``Authorization`` header value or a URL and serialized back
    into a header value.
    """

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        # The first colon separates login from password on the wire, so
        # the login itself must not contain one.
        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Build an instance from the value of an Authorization header.

        Raises ValueError when the header has no space-separated scheme,
        the scheme is not ``basic`` (case-insensitive), the credentials
        are not valid base64, or the decoded text contains no colon.
        """
        try:
            scheme, b64_credentials = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.")

        if scheme.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % scheme)

        try:
            raw = base64.b64decode(b64_credentials.encode("ascii"), validate=True)
            credentials = raw.decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # A colon is mandatory; either side of it may be empty.
            login, secret = credentials.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.")

        return cls(login, secret, encoding=encoding)

    @classmethod
    def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]:
        """Build an instance from the user info of ``url``.

        Returns ``None`` when the URL carries neither a user nor a
        password. Raises TypeError if ``url`` is not a ``yarl.URL``.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        # Check raw_user and raw_password first as yarl is likely
        # to already have these values parsed from the netloc in the cache.
        if url.raw_user is not None or url.raw_password is not None:
            return cls(url.user or "", url.password or "", encoding=encoding)
        return None

    def encode(self) -> str:
        """Return the credentials as an Authorization header value."""
        pair = f"{self.login}:{self.password}"
        token = base64.b64encode(pair.encode(self.encoding))
        return "Basic %s" % token.decode(self.encoding)
def basicauth_from_netrc(netrc_obj: netrc.netrc | None, host: str) -> BasicAuth:
    """
    Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.

    The login is used as the username unless it is empty and an account
    is present, in which case the account is used. A missing password
    becomes the empty string.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
            entry is found for the ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    entry = netrc_obj.authenticators(host)
    if entry is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")

    login, account, password = entry

    # TODO(PY311): username = login or account
    # Up to python 3.10, account could be None if not specified,
    # and login will be empty string if not specified. From 3.11,
    # login and account will be empty string if not specified.
    if login or account is None:
        username = login
    else:
        username = account

    # TODO(PY311): Remove this, as password will be empty string
    # if not specified
    secret = "" if password is None else password

    return BasicAuth(username, secret)
class EnsureOctetStream(EmailMessage):
    """Email message whose fallback content type is application/octet-stream."""

    def __init__(self) -> None:
        super().__init__()
        # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
        self.set_default_type("application/octet-stream")

    def get_content_type(self) -> str:
        """Return the lower-cased media type of the Content-Type header.

        Re-implementation of the method from Message: when the value is
        missing or does not contain exactly one ``/``, the default type
        (application/octet-stream here, rather than text/plain) is
        returned. Parameters after the first ``;`` are ignored.
        """
        header = self.get("content-type", "")
        # Based on the implementation of _splitparam in the standard library
        media_type = header.lower().partition(";")[0].strip()
        if media_type.count("/") == 1:
            return media_type
        return self.get_default_type()
The default returned value is `application/octet-stream` """ msg = HeaderParser(EnsureOctetStream, policy=HTTP).parsestr(f"Content-Type: {raw}") content_type = msg.get_content_type() params = msg.get_params(()) content_dict = dict(params[1:]) # First element is content type again return content_type, MappingProxyType(content_dict) def guess_filename(obj: Any, default: str | None = None) -> str | None: name = getattr(obj, "name", None) if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">": return Path(name).name return default not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]") QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"} def quoted_string(content: str) -> str: """Return 7-bit content as quoted-string. Format content into a quoted-string as defined in RFC5322 for Internet Message Format. Notice that this is not the 8-bit HTTP format, but the 7-bit email format. Content must be in usascii or a ValueError is raised. """ if not (QCONTENT > set(content)): raise ValueError(f"bad content for quoted-string {content!r}") return not_qtext_re.sub(lambda x: "\\" + x.group(0), content) def content_disposition_header( disptype: str, quote_fields: bool = True, _charset: str = "utf-8", params: dict[str, str] | None = None, ) -> str: """Sets ``Content-Disposition`` header for MIME. This is the MIME payload Content-Disposition header from RFC 2183 and RFC 7579 section 4.2, not the HTTP Content-Disposition from RFC 6266. disptype is a disposition type: inline, attachment, form-data. Should be valid extension token (see RFC 2183) quote_fields performs value quoting to 7-bit MIME headers according to RFC 7578. Set to quote_fields to False if recipient can take 8-bit file names and field values. _charset specifies the charset to use when quote_fields is True. params is a dict with disposition params. 
""" if not disptype or not (TOKEN > set(disptype)): raise ValueError(f"bad content disposition type {disptype!r}") value = disptype if params: lparams = [] for key, val in params.items(): if not key or not (TOKEN > set(key)): raise ValueError(f"bad content disposition parameter {key!r}={val!r}") if quote_fields: if key.lower() == "filename": qval = quote(val, "", encoding=_charset) lparams.append((key, '"%s"' % qval)) else: try: qval = quoted_string(val) except ValueError: qval = "".join( (_charset, "''", quote(val, "", encoding=_charset)) ) lparams.append((key + "*", qval)) else: lparams.append((key, '"%s"' % qval)) else: qval = val.replace("\\", "\\\\").replace('"', '\\"') lparams.append((key, '"%s"' % qval)) sparams = "; ".join("=".join(pair) for pair in lparams) value = "; ".join((value, sparams)) return value def is_expected_content_type( response_content_type: str, expected_content_type: str ) -> bool: """Checks if received content type is processable as an expected one. Both arguments should be given without parameters. """ if expected_content_type == "application/json": return json_re.match(response_content_type) is not None return expected_content_type in response_content_type def is_ip_address(host: str | None) -> bool: """Check if host looks like an IP Address. This check is only meant as a heuristic to ensure that a host is not a domain name. """ if not host: return False # For a host to be an ipv4 address, it must be all numeric. # The host must contain a colon to be an IPv6 address. return ":" in host or host.replace(".", "").isdigit() _cached_current_datetime: int | None = None _cached_formatted_datetime = "" def rfc822_formatted_time() -> str: global _cached_current_datetime global _cached_formatted_datetime now = int(time.time()) if now != _cached_current_datetime: # Weekday and month names for HTTP date/time formatting; # always English! # Tuples are constants stored in codeobject! 
def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Calculate when to execute a timeout.

    The deadline is ``loop_time + timeout``. It is rounded up to a whole
    number only when ``timeout`` is strictly greater than
    ``timeout_ceiling_threshold``.
    """
    deadline = loop_time + timeout
    return ceil(deadline) if timeout > timeout_ceiling_threshold else deadline
callback: Callable[..., None], *args: Any, **kwargs: Any ) -> None: self._callbacks.append((callback, args, kwargs)) def close(self) -> None: self._callbacks.clear() def start(self) -> asyncio.TimerHandle | None: timeout = self._timeout if timeout is not None and timeout > 0: when = self._loop.time() + timeout if timeout >= self._ceil_threshold: when = ceil(when) return self._loop.call_at(when, self.__call__) else: return None def timer(self) -> "BaseTimerContext": if self._timeout is not None and self._timeout > 0: timer = TimerContext(self._loop) self.register(timer.timeout) return timer else: return TimerNoop() def __call__(self) -> None: for cb, args, kwargs in self._callbacks: with suppress(Exception): cb(*args, **kwargs) self._callbacks.clear() class BaseTimerContext(ContextManager["BaseTimerContext"]): __slots__ = () def assert_timeout(self) -> None: """Raise TimeoutError if timeout has been exceeded.""" class TimerNoop(BaseTimerContext): __slots__ = () def __enter__(self) -> BaseTimerContext: return self def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: return class TimerContext(BaseTimerContext): """Low resolution timeout context manager""" __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling") def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop = loop self._tasks: list[asyncio.Task[Any]] = [] self._cancelled = False self._cancelling = 0 def assert_timeout(self) -> None: """Raise TimeoutError if timer has already been cancelled.""" if self._cancelled: raise asyncio.TimeoutError from None def __enter__(self) -> BaseTimerContext: task = asyncio.current_task(loop=self._loop) if task is None: raise RuntimeError("Timeout context manager should be used inside a task") if sys.version_info >= (3, 11): # Remember if the task was already cancelling # so when we __exit__ we can decide if we should # raise asyncio.TimeoutError or let the cancellation propagate 
self._cancelling = task.cancelling() if self._cancelled: raise asyncio.TimeoutError from None self._tasks.append(task) return self def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: enter_task: asyncio.Task[Any] | None = None if self._tasks: enter_task = self._tasks.pop() if exc_type is asyncio.CancelledError and self._cancelled: assert enter_task is not None # The timeout was hit, and the task was cancelled # so we need to uncancel the last task that entered the context manager # since the cancellation should not leak out of the context manager if sys.version_info >= (3, 11): # If the task was already cancelling don't raise # asyncio.TimeoutError and instead return None # to allow the cancellation to propagate if enter_task.uncancel() > self._cancelling: return None raise asyncio.TimeoutError from exc_val return None def timeout(self) -> None: if not self._cancelled: for task in set(self._tasks): task.cancel() self._cancelled = True def ceil_timeout( delay: float | None, ceil_threshold: float = 5 ) -> async_timeout.Timeout: if delay is None or delay <= 0: return async_timeout.timeout(None) loop = asyncio.get_running_loop() now = loop.time() when = now + delay if delay > ceil_threshold: when = ceil(when) return async_timeout.timeout_at(when) class HeadersMixin: """Mixin for handling headers.""" _headers: MultiMapping[str] _content_type: str | None = None _content_dict: dict[str, str] | None = None _stored_content_type: str | None | _SENTINEL = sentinel def _parse_content_type(self, raw: str | None) -> None: self._stored_content_type = raw if raw is None: # default value according to RFC 2616 self._content_type = "application/octet-stream" self._content_dict = {} else: content_type, content_mapping_proxy = parse_content_type(raw) self._content_type = content_type # _content_dict needs to be mutable so we can update it self._content_dict = content_mapping_proxy.copy() @property def 
content_type(self) -> str: """The value of content part for Content-Type HTTP header.""" raw = self._headers.get(hdrs.CONTENT_TYPE) if self._stored_content_type != raw: self._parse_content_type(raw) assert self._content_type is not None return self._content_type @property def charset(self) -> str | None: """The value of charset part for Content-Type HTTP header.""" raw = self._headers.get(hdrs.CONTENT_TYPE) if self._stored_content_type != raw: self._parse_content_type(raw) assert self._content_dict is not None return self._content_dict.get("charset") @property def content_length(self) -> int | None: """The value of Content-Length HTTP header.""" content_length = self._headers.get(hdrs.CONTENT_LENGTH) return None if content_length is None else int(content_length) def set_result(fut: "asyncio.Future[_T]", result: _T) -> None: if not fut.done(): fut.set_result(result) _EXC_SENTINEL = BaseException() class ErrorableProtocol(Protocol): def set_exception( self, exc: type[BaseException] | BaseException, exc_cause: BaseException = ..., ) -> None: ... def set_exception( fut: Union["asyncio.Future[_T]", ErrorableProtocol], exc: type[BaseException] | BaseException, exc_cause: BaseException = _EXC_SENTINEL, ) -> None: """Set future exception. If the future is marked as complete, this function is a no-op. :param exc_cause: An exception that is a direct cause of ``exc``. Only set if provided. """ if asyncio.isfuture(fut) and fut.done(): return exc_is_sentinel = exc_cause is _EXC_SENTINEL exc_causes_itself = exc is exc_cause if not exc_is_sentinel and not exc_causes_itself: exc.__cause__ = exc_cause fut.set_exception(exc) @functools.total_ordering class BaseKey(Generic[_T]): """Base for concrete context storage key classes. Each storage is provided with its own sub-class for the sake of some additional type safety. """ __slots__ = ("_name", "_t", "__orig_class__") # This may be set by Python when instantiating with a generic type. 
We need to # support this, in order to support types that are not concrete classes, # like Iterable, which can't be passed as the second parameter to __init__. __orig_class__: type[object] # TODO(PY314): Change Type to TypeForm (this should resolve unreachable below). def __init__(self, name: str, t: type[_T] | None = None): # Prefix with module name to help deduplicate key names. frame = inspect.currentframe() while frame: if frame.f_code.co_name == "": module: str = frame.f_globals["__name__"] break frame = frame.f_back else: raise RuntimeError("Failed to get module name.") # https://github.com/python/mypy/issues/14209 self._name = module + "." + name # type: ignore[possibly-undefined] self._t = t def __lt__(self, other: object) -> bool: if isinstance(other, BaseKey): return self._name < other._name return True # Order BaseKey above other types. def __repr__(self) -> str: t = self._t if t is None: with suppress(AttributeError): # Set to type arg. t = get_args(self.__orig_class__)[0] if t is None: t_repr = "<>" elif isinstance(t, type): if t.__module__ == "builtins": t_repr = t.__qualname__ else: t_repr = f"{t.__module__}.{t.__qualname__}" else: t_repr = repr(t) # type: ignore[unreachable] return f"<{self.__class__.__name__}({self._name}, type={t_repr})>" class AppKey(BaseKey[_T]): """Keys for static typing support in Application.""" class RequestKey(BaseKey[_T]): """Keys for static typing support in Request.""" class ResponseKey(BaseKey[_T]): """Keys for static typing support in Response.""" @final class ChainMapProxy(Mapping[str | AppKey[Any], Any]): __slots__ = ("_maps",) def __init__(self, maps: Iterable[Mapping[str | AppKey[Any], Any]]) -> None: self._maps = tuple(maps) def __init_subclass__(cls) -> None: raise TypeError( f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden" ) @overload # type: ignore[override] def __getitem__(self, key: AppKey[_T]) -> _T: ... @overload def __getitem__(self, key: str) -> Any: ... 
def __getitem__(self, key: str | AppKey[_T]) -> Any: for mapping in self._maps: try: return mapping[key] except KeyError: pass raise KeyError(key) @overload # type: ignore[override] def get(self, key: AppKey[_T], default: _S) -> _T | _S: ... @overload def get(self, key: AppKey[_T], default: None = ...) -> _T | None: ... @overload def get(self, key: str, default: Any = ...) -> Any: ... def get(self, key: str | AppKey[_T], default: Any = None) -> Any: try: return self[key] except KeyError: return default def __len__(self) -> int: # reuses stored hash values if possible return len(set().union(*self._maps)) def __iter__(self) -> Iterator[str | AppKey[Any]]: d: dict[str | AppKey[Any], Any] = {} for mapping in reversed(self._maps): # reuses stored hash values if possible d.update(mapping) return iter(d) def __contains__(self, key: object) -> bool: return any(key in m for m in self._maps) def __bool__(self) -> bool: return any(self._maps) def __repr__(self) -> str: content = ", ".join(map(repr, self._maps)) return f"ChainMapProxy({content})" class CookieMixin: """Mixin for handling cookies.""" _cookies: SimpleCookie | None = None @property def cookies(self) -> SimpleCookie: if self._cookies is None: self._cookies = SimpleCookie() return self._cookies def set_cookie( self, name: str, value: str, *, expires: str | None = None, domain: str | None = None, max_age: int | str | None = None, path: str = "/", secure: bool | None = None, httponly: bool | None = None, samesite: str | None = None, partitioned: bool | None = None, ) -> None: """Set or update response cookie. Sets new cookie or updates existent with new value. Also updates only those params which are not None. 
""" if self._cookies is None: self._cookies = SimpleCookie() self._cookies[name] = value c = self._cookies[name] if expires is not None: c["expires"] = expires elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT": del c["expires"] if domain is not None: c["domain"] = domain if max_age is not None: c["max-age"] = str(max_age) elif "max-age" in c: del c["max-age"] c["path"] = path if secure is not None: c["secure"] = secure if httponly is not None: c["httponly"] = httponly if samesite is not None: c["samesite"] = samesite if partitioned is not None: c["partitioned"] = partitioned if DEBUG: cookie_length = len(c.output(header="")[1:]) if cookie_length > COOKIE_MAX_LENGTH: warnings.warn( "The size of is too large, it might get ignored by the client.", UserWarning, stacklevel=2, ) def del_cookie( self, name: str, *, domain: str | None = None, path: str = "/", secure: bool | None = None, httponly: bool | None = None, samesite: str | None = None, ) -> None: """Delete cookie. Creates new empty expired cookie. """ # TODO: do we need domain/path here? if self._cookies is not None: self._cookies.pop(name, None) self.set_cookie( name, "", max_age=0, expires="Thu, 01 Jan 1970 00:00:00 GMT", domain=domain, path=path, secure=secure, httponly=httponly, samesite=samesite, ) def populate_with_cookies(headers: "CIMultiDict[str]", cookies: SimpleCookie) -> None: for cookie in cookies.values(): value = cookie.output(header="")[1:] headers.add(hdrs.SET_COOKIE, value) # https://tools.ietf.org/html/rfc7232#section-2.3 _ETAGC = r"[!\x23-\x7E\x80-\xff]+" _ETAGC_RE = re.compile(_ETAGC) _QUOTED_ETAG = rf'(W/)?"({_ETAGC})"' QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG) LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)") ETAG_ANY = "*" @frozen_dataclass_decorator class ETag: value: str is_weak: bool = False def validate_etag_value(value: str) -> None: if value != ETAG_ANY and not _ETAGC_RE.fullmatch(value): raise ValueError( f"Value {value!r} is not a valid etag. 
Maybe it contains '\"'?" ) def parse_http_date(date_str: str | None) -> datetime.datetime | None: """Process a date string, return a datetime object""" if date_str is not None: timetuple = parsedate(date_str) if timetuple is not None: with suppress(ValueError): return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc) return None @functools.lru_cache def must_be_empty_body(method: str, code: int) -> bool: """Check if a request must return an empty body.""" return ( code in EMPTY_BODY_STATUS_CODES or method in EMPTY_BODY_METHODS or (200 <= code < 300 and method in hdrs.METH_CONNECT_ALL) ) def should_remove_content_length(method: str, code: int) -> bool: """Check if a Content-Length header should be removed. This should always be a subset of must_be_empty_body """ # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8 # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4 return code in EMPTY_BODY_STATUS_CODES or ( 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL ) ================================================ FILE: aiohttp/http.py ================================================ import sys from . 
import __version__ from .http_exceptions import HttpProcessingError from .http_parser import ( HeadersParser, HttpParser, HttpRequestParser, HttpResponseParser, RawRequestMessage, RawResponseMessage, ) from .http_websocket import ( WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, WS_KEY, WebSocketError, WebSocketReader, WebSocketWriter, WSCloseCode, WSMessage, WSMessageDecodeText, WSMessageNoDecodeText, WSMsgType, ws_ext_gen, ws_ext_parse, ) from .http_writer import HttpVersion, HttpVersion10, HttpVersion11, StreamWriter __all__ = ( "HttpProcessingError", "SERVER_SOFTWARE", # .http_writer "StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11", # .http_parser "HeadersParser", "HttpParser", "HttpRequestParser", "HttpResponseParser", "RawRequestMessage", "RawResponseMessage", # .http_websocket "WS_CLOSED_MESSAGE", "WS_CLOSING_MESSAGE", "WS_KEY", "WebSocketReader", "WebSocketWriter", "ws_ext_gen", "ws_ext_parse", "WSMessage", "WSMessageDecodeText", "WSMessageNoDecodeText", "WebSocketError", "WSMsgType", "WSCloseCode", ) SERVER_SOFTWARE: str = ( f"Python/{sys.version_info[0]}.{sys.version_info[1]} aiohttp/{__version__}" ) ================================================ FILE: aiohttp/http_exceptions.py ================================================ """Low-level http related exceptions.""" from textwrap import indent from multidict import CIMultiDict __all__ = ("HttpProcessingError",) class HttpProcessingError(Exception): """HTTP error. Shortcut for raising HTTP errors with custom code, message and headers. code: HTTP Error code. message: (optional) Error message. 
class InvalidHeader(BadHttpMessage):
    """Bad-message error describing a header that failed validation.

    The message embeds ``repr(hdr)``. ``hdr`` keeps a text form of the
    header (bytes are decoded with ``backslashreplace``), while ``args``
    keeps the header exactly as it was passed in.
    """

    def __init__(self, hdr: bytes | str) -> None:
        if isinstance(hdr, bytes):
            printable = hdr.decode(errors="backslashreplace")
        else:
            printable = hdr
        super().__init__(f"Invalid HTTP header: {hdr!r}")
        self.hdr = printable
        # Set after the parent initialiser, which stores the message in args.
        self.args = (hdr,)
class InvalidURLError(BadHttpMessage):
    """Bad-message error for a request target that is not an acceptable URL.

    The request-line parser in this file raises it when the target matches
    none of the accepted forms and has no scheme. The class adds nothing of
    its own; status code and constructor come from ``BadHttpMessage``.
    """

    pass
class HeadersParser:
    """Turn raw header lines into a case-insensitive multidict.

    ``lax`` selects the tolerant mode: obsolete line folding is accepted
    and only CR, LF and NUL are rejected inside values. In strict mode
    folding is not recognised and every character matched by
    ``_FIELD_VALUE_FORBIDDEN_CTL_RE`` is rejected.
    """

    def __init__(self, max_field_size: int = 8190, lax: bool = False) -> None:
        # Upper bound checked while accumulating a folded (multi-line) value.
        self.max_field_size = max_field_size
        self._lax = lax

    def parse_headers(
        self, lines: list[bytes]
    ) -> tuple["CIMultiDictProxy[str]", RawHeaders]:
        """Parse *lines* up to the first empty line.

        :param lines: raw header lines. Parsing stops at the first empty
            entry, so the list has to contain one; outside the folding
            loop the next line is indexed without a bounds check.
        :return: a read-only multidict of decoded ``name -> value`` pairs
            and a tuple of the corresponding ``(bname, bvalue)`` byte
            pairs, both in input order.
        :raises InvalidHeader: missing colon, empty name, whitespace at
            either end of the name, a name that is not a token, or a
            forbidden character in the value.
        :raises LineTooLong: a folded value grows beyond
            ``max_field_size``.
        """
        headers: CIMultiDict[str] = CIMultiDict()
        # note: "raw" does not mean inclusion of OWS before/after the field value
        raw_headers = []

        lines_idx = 0
        line = lines[lines_idx]
        line_count = len(lines)

        while line:
            # Parse initial header name : value pair.
            try:
                bname, bvalue = line.split(b":", 1)
            except ValueError:
                raise InvalidHeader(line) from None

            if len(bname) == 0:
                raise InvalidHeader(bname)

            # https://www.rfc-editor.org/rfc/rfc9112.html#section-5.1-2
            # Indexing bytes yields ints, hence the numeric code points.
            if {bname[0], bname[-1]} & {32, 9}:  # {" ", "\t"}
                raise InvalidHeader(line)

            bvalue = bvalue.lstrip(b" \t")
            name = bname.decode("utf-8", "surrogateescape")
            if not TOKENRE.fullmatch(name):
                raise InvalidHeader(bname)

            # next line
            lines_idx += 1
            line = lines[lines_idx]

            # consume continuation lines
            # Only lax mode treats a line starting with SP/HTAB as a fold.
            continuation = self._lax and line and line[0] in (32, 9)  # (' ', '\t')

            # Deprecated: https://www.rfc-editor.org/rfc/rfc9112.html#name-obsolete-line-folding
            if continuation:
                header_length = len(bvalue)
                bvalue_lst = [bvalue]
                while continuation:
                    header_length += len(line)
                    if header_length > self.max_field_size:
                        header_line = bname + b": " + b"".join(bvalue_lst)
                        # Only the first 100 bytes go into the error.
                        raise LineTooLong(
                            header_line[:100] + b"...", self.max_field_size
                        )
                    bvalue_lst.append(line)

                    # next line
                    lines_idx += 1
                    if lines_idx < line_count:
                        line = lines[lines_idx]
                        if line:
                            continuation = line[0] in (32, 9)  # (' ', '\t')
                    else:
                        # Ran out of input: end both loops via the empty line.
                        line = b""
                        break
                # Folded pieces are concatenated as-is, without a separator.
                bvalue = b"".join(bvalue_lst)

            bvalue = bvalue.strip(b" \t")
            value = bvalue.decode("utf-8", "surrogateescape")

            # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-5
            if self._lax:
                if "\n" in value or "\r" in value or "\x00" in value:
                    raise InvalidHeader(bvalue)
            elif _FIELD_VALUE_FORBIDDEN_CTL_RE.search(value):
                raise InvalidHeader(bvalue)

            headers.add(name, value)
            raw_headers.append((bname, bvalue))

        return (CIMultiDictProxy(headers), tuple(raw_headers))
def __init__(
    self,
    protocol: BaseProtocol,
    loop: asyncio.AbstractEventLoop,
    limit: int,
    max_line_size: int = 8190,
    max_headers: int = 128,
    max_field_size: int = 8190,
    timer: BaseTimerContext | None = None,
    code: int | None = None,
    method: str | None = None,
    payload_exception: type[BaseException] | None = None,
    response_with_body: bool = True,
    read_until_eof: bool = False,
    auto_decompress: bool = True,
) -> None:
    """Store the parser configuration and reset the parsing state.

    Every argument is kept on the instance: the public ones under their
    own name, ``limit`` and ``auto_decompress`` under a leading
    underscore. A ``HeadersParser`` is created from *max_field_size* and
    the class-level ``lax`` flag.
    """
    # Collaborators.
    self.protocol = protocol
    self.loop = loop
    self.timer = timer

    # Size limits.
    self.max_line_size = max_line_size
    self.max_field_size = max_field_size
    self.max_headers = max_headers
    self._limit = limit

    # Message and payload handling options.
    self.code = code
    self.method = method
    self.payload_exception = payload_exception
    self.response_with_body = response_with_body
    self.read_until_eof = read_until_eof
    self._auto_decompress = auto_decompress

    # Mutable parsing state.
    self._lines: list[bytes] = []
    self._tail = b""
    self._upgraded = False
    self._payload = None
    self._payload_parser: HttpPayloadParser | None = None

    self._headers_parser = HeadersParser(max_field_size, self.lax)
def feed_eof(self) -> _MsgT | None: if self._payload_parser is not None: self._payload_parser.feed_eof() self._payload_parser = None else: # try to extract partial message if self._tail: self._lines.append(self._tail) if self._lines: if self._lines[-1] != "\r\n": self._lines.append(b"") with suppress(Exception): return self.parse_message(self._lines) return None def feed_data( self, data: bytes, SEP: _SEP = b"\r\n", EMPTY: bytes = b"", CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH, METH_CONNECT: str = hdrs.METH_CONNECT, SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1, ) -> tuple[list[tuple[_MsgT, StreamReader]], bool, bytes]: messages = [] if self._tail: data, self._tail = self._tail + data, b"" data_len = len(data) start_pos = 0 loop = self.loop max_line_length = self.max_line_size should_close = False while start_pos < data_len: # read HTTP message (request/response line + headers), \r\n\r\n # and split by lines if self._payload_parser is None and not self._upgraded: pos = data.find(SEP, start_pos) # consume \r\n if pos == start_pos and not self._lines: start_pos = pos + len(SEP) continue if pos >= start_pos: if should_close: raise BadHttpMessage("Data after `Connection: close`") # line found line = data[start_pos:pos] if SEP == b"\n": # For lax response parsing line = line.rstrip(b"\r") if len(line) > max_line_length: raise LineTooLong(line[:100] + b"...", max_line_length) self._lines.append(line) # After processing the status/request line, everything is a header. 
max_line_length = self.max_field_size if len(self._lines) > self.max_headers: raise BadHttpMessage("Too many headers received") start_pos = pos + len(SEP) # \r\n\r\n found if self._lines[-1] == EMPTY: max_trailers = self.max_headers - len(self._lines) try: msg: _MsgT = self.parse_message(self._lines) finally: self._lines.clear() def get_content_length() -> int | None: # payload length length_hdr = msg.headers.get(CONTENT_LENGTH) if length_hdr is None: return None # Shouldn't allow +/- or other number formats. # https://www.rfc-editor.org/rfc/rfc9110#section-8.6-2 # msg.headers is already stripped of leading/trailing wsp if not DIGITS.fullmatch(length_hdr): raise InvalidHeader(CONTENT_LENGTH) return int(length_hdr) length = get_content_length() # do not support old websocket spec if SEC_WEBSOCKET_KEY1 in msg.headers: raise InvalidHeader(SEC_WEBSOCKET_KEY1) self._upgraded = msg.upgrade and _is_supported_upgrade( msg.headers ) method = getattr(msg, "method", self.method) # code is only present on responses code = getattr(msg, "code", 0) assert self.protocol is not None # calculate payload empty_body = code in EMPTY_BODY_STATUS_CODES or bool( method and method in EMPTY_BODY_METHODS ) if not empty_body and ( ((length is not None and length > 0) or msg.chunked) and not self._upgraded ): payload = StreamReader( self.protocol, timer=self.timer, loop=loop, limit=self._limit, ) payload_parser = HttpPayloadParser( payload, length=length, chunked=msg.chunked, method=method, compression=msg.compression, code=self.code, response_with_body=self.response_with_body, auto_decompress=self._auto_decompress, lax=self.lax, headers_parser=self._headers_parser, max_line_size=self.max_line_size, max_field_size=self.max_field_size, max_trailers=max_trailers, ) if not payload_parser.done: self._payload_parser = payload_parser elif method == METH_CONNECT: assert isinstance(msg, RawRequestMessage) payload = StreamReader( self.protocol, timer=self.timer, loop=loop, limit=self._limit, ) 
self._upgraded = True self._payload_parser = HttpPayloadParser( payload, method=msg.method, compression=msg.compression, auto_decompress=self._auto_decompress, lax=self.lax, headers_parser=self._headers_parser, max_line_size=self.max_line_size, max_field_size=self.max_field_size, max_trailers=max_trailers, ) elif not empty_body and length is None and self.read_until_eof: payload = StreamReader( self.protocol, timer=self.timer, loop=loop, limit=self._limit, ) payload_parser = HttpPayloadParser( payload, length=length, chunked=msg.chunked, method=method, compression=msg.compression, code=self.code, response_with_body=self.response_with_body, auto_decompress=self._auto_decompress, lax=self.lax, headers_parser=self._headers_parser, max_line_size=self.max_line_size, max_field_size=self.max_field_size, max_trailers=max_trailers, ) if not payload_parser.done: self._payload_parser = payload_parser else: payload = EMPTY_PAYLOAD messages.append((msg, payload)) should_close = msg.should_close else: self._tail = data[start_pos:] if len(self._tail) > self.max_line_size: raise LineTooLong(self._tail[:100] + b"...", self.max_line_size) data = EMPTY break # no parser, just store elif self._payload_parser is None and self._upgraded: assert not self._lines break # feed payload elif data and start_pos < data_len: assert not self._lines assert self._payload_parser is not None try: eof, data = self._payload_parser.feed_data(data[start_pos:], SEP) except Exception as underlying_exc: reraised_exc: BaseException = underlying_exc if self.payload_exception is not None: reraised_exc = self.payload_exception(str(underlying_exc)) set_exception( self._payload_parser.payload, reraised_exc, underlying_exc, ) eof = True data = b"" if isinstance( underlying_exc, (InvalidHeader, TransferEncodingError) ): raise if eof: start_pos = 0 data_len = len(data) self._payload_parser = None continue else: break if data and start_pos < data_len: data = data[start_pos:] else: data = EMPTY return messages, 
def parse_headers(
    self, lines: list[bytes]
) -> tuple[
    "CIMultiDictProxy[str]", RawHeaders, bool | None, str | None, bool, bool
]:
    """Parse header lines and derive the connection-level flags.

    :param lines: header lines, passed to the configured headers parser.
    :return: ``(headers, raw_headers, close_conn, encoding, upgrade,
        chunked)`` where

        * ``close_conn`` is ``True``/``False`` for a ``close``/``keep-alive``
          Connection token and ``None`` when neither is present;
        * ``encoding`` is the Content-Encoding value exactly as sent when
          it is one of gzip, deflate, br or zstd, else ``None``;
        * ``upgrade`` is ``True`` when Connection lists ``upgrade`` and a
          non-empty Upgrade header exists;
        * ``chunked`` is ``True`` when ``_is_chunked_te`` accepts the
          Transfer-Encoding value.
    :raises BadHttpMessage: a singleton header occurs more than once, or
        Transfer-Encoding and Content-Length are both present.
    """
    headers, raw_headers = self._headers_parser.parse_headers(lines)

    # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-6
    # https://www.rfc-editor.org/rfc/rfc9110.html#name-collected-abnf
    for singleton in (
        hdrs.CONTENT_LENGTH,
        hdrs.CONTENT_LOCATION,
        hdrs.CONTENT_RANGE,
        hdrs.CONTENT_TYPE,
        hdrs.ETAG,
        hdrs.HOST,
        hdrs.MAX_FORWARDS,
        hdrs.SERVER,
        hdrs.TRANSFER_ENCODING,
        hdrs.USER_AGENT,
    ):
        if len(headers.getall(singleton, ())) > 1:
            raise BadHttpMessage(f"Duplicate '{singleton}' header found.")

    # keep-alive and protocol switching
    # RFC 9110 section 7.6.1 defines Connection as a comma-separated list.
    close_conn: bool | None = None
    upgrade = False
    connection_fields = headers.getall(hdrs.CONNECTION, ())
    if connection_fields:
        tokens: set[str] = set()
        for field in connection_fields:
            for part in field.split(","):
                token = part.strip(" \t")
                # Non-ASCII tokens are skipped before lower-casing.
                if token and token.isascii():
                    tokens.add(token.lower())

        if "close" in tokens:
            close_conn = True
        elif "keep-alive" in tokens:
            close_conn = False

        # https://www.rfc-editor.org/rfc/rfc9110.html#name-101-switching-protocols
        if "upgrade" in tokens and headers.get(hdrs.UPGRADE):
            upgrade = True

    # encoding
    content_coding = headers.get(hdrs.CONTENT_ENCODING, "")
    recognised = content_coding.isascii() and content_coding.lower() in {
        "gzip",
        "deflate",
        "br",
        "zstd",
    }
    encoding = content_coding if recognised else None

    # chunking
    chunked = False
    transfer_coding = headers.get(hdrs.TRANSFER_ENCODING)
    if transfer_coding is not None:
        if self._is_chunked_te(transfer_coding):
            chunked = True

        if hdrs.CONTENT_LENGTH in headers:
            raise BadHttpMessage(
                "Transfer-Encoding can't be present with Content-Length",
            )

    return (headers, raw_headers, close_conn, encoding, upgrade, chunked)
""" def parse_message(self, lines: list[bytes]) -> RawRequestMessage: # request line line = lines[0].decode("utf-8", "surrogateescape") try: method, path, version = line.split(" ", maxsplit=2) except ValueError: raise BadHttpMethod(line) from None # method if not TOKENRE.fullmatch(method): raise BadHttpMethod(method) # version match = VERSRE.fullmatch(version) if match is None: raise BadStatusLine(line) version_o = HttpVersion(int(match.group(1)), int(match.group(2))) if method == "CONNECT": # authority-form, # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3 url = URL.build(authority=path, encoded=True) elif path.startswith("/"): # origin-form, # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1 path_part, _hash_separator, url_fragment = path.partition("#") path_part, _question_mark_separator, qs_part = path_part.partition("?") # NOTE: `yarl.URL.build()` is used to mimic what the Cython-based # NOTE: parser does, otherwise it results into the same # NOTE: HTTP Request-Line input producing different # NOTE: `yarl.URL()` objects url = URL.build( path=path_part, query_string=qs_part, fragment=url_fragment, encoded=True, ) elif path == "*" and method == "OPTIONS": # asterisk-form, url = URL(path, encoded=True) else: # absolute-form for proxy maybe, # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2 url = URL(path, encoded=True) if url.scheme == "": # not absolute-form raise InvalidURLError( path.encode(errors="surrogateescape").decode("latin1") ) # read headers ( headers, raw_headers, close, compression, upgrade, chunked, ) = self.parse_headers(lines[1:]) if close is None: # then the headers weren't set in the request if version_o <= HttpVersion10: # HTTP 1.0 must asks to not close close = True else: # HTTP 1.1 must ask to close. 
close = False return RawRequestMessage( method, path, version_o, headers, raw_headers, close, compression, upgrade, chunked, url, ) def _is_chunked_te(self, te: str) -> bool: # https://www.rfc-editor.org/rfc/rfc9112#section-7.1-3 # "A sender MUST NOT apply the chunked transfer coding more # than once to a message body" parts = [p.strip(" \t") for p in te.split(",")] chunked_count = sum(1 for p in parts if p.isascii() and p.lower() == "chunked") if chunked_count > 1: raise BadHttpMessage("Request has duplicate `chunked` Transfer-Encoding") last = parts[-1] # .lower() transforms some non-ascii chars, so must check first. if last.isascii() and last.lower() == "chunked": return True # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.3 raise BadHttpMessage("Request has invalid `Transfer-Encoding`") class HttpResponseParser(HttpParser[RawResponseMessage]): """Read response status line and headers. BadStatusLine could be raised in case of any errors in status line. Returns RawResponseMessage. """ # Lax mode should only be enabled on response parser. 
lax = not DEBUG def feed_data( self, data: bytes, SEP: _SEP | None = None, *args: Any, **kwargs: Any, ) -> tuple[list[tuple[RawResponseMessage, StreamReader]], bool, bytes]: if SEP is None: SEP = b"\r\n" if DEBUG else b"\n" return super().feed_data(data, SEP, *args, **kwargs) def parse_message(self, lines: list[bytes]) -> RawResponseMessage: line = lines[0].decode("utf-8", "surrogateescape") try: version, status = line.split(maxsplit=1) except ValueError: raise BadStatusLine(line) from None try: status, reason = status.split(maxsplit=1) except ValueError: status = status.strip() reason = "" # version match = VERSRE.fullmatch(version) if match is None: raise BadStatusLine(line) version_o = HttpVersion(int(match.group(1)), int(match.group(2))) # The status code is a three-digit ASCII number, no padding if len(status) != 3 or not DIGITS.fullmatch(status): raise BadStatusLine(line) status_i = int(status) # read headers ( headers, raw_headers, close, compression, upgrade, chunked, ) = self.parse_headers(lines[1:]) if close is None: if version_o <= HttpVersion10: close = True # https://www.rfc-editor.org/rfc/rfc9112.html#name-message-body-length elif 100 <= status_i < 200 or status_i in {204, 304}: close = False elif hdrs.CONTENT_LENGTH in headers or hdrs.TRANSFER_ENCODING in headers: close = False else: # https://www.rfc-editor.org/rfc/rfc9112.html#section-6.3-2.8 close = True return RawResponseMessage( version_o, status_i, reason.strip(), headers, raw_headers, close, compression, upgrade, chunked, ) def _is_chunked_te(self, te: str) -> bool: # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.2 return te.rsplit(",", maxsplit=1)[-1].strip(" \t").lower() == "chunked" class HttpPayloadParser: def __init__( self, payload: StreamReader, length: int | None = None, chunked: bool = False, compression: str | None = None, code: int | None = None, method: str | None = None, response_with_body: bool = True, auto_decompress: bool = True, lax: bool = False, *, headers_parser: 
HeadersParser, max_line_size: int = 8190, max_field_size: int = 8190, max_trailers: int = 128, ) -> None: self._length = 0 self._type = ParseState.PARSE_UNTIL_EOF self._chunk = ChunkState.PARSE_CHUNKED_SIZE self._chunk_size = 0 self._chunk_tail = b"" self._auto_decompress = auto_decompress self._lax = lax self._headers_parser = headers_parser self._max_line_size = max_line_size self._max_field_size = max_field_size self._max_trailers = max_trailers self._trailer_lines: list[bytes] = [] self.done = False # payload decompression wrapper if response_with_body and compression and self._auto_decompress: real_payload: StreamReader | DeflateBuffer = DeflateBuffer( payload, compression ) else: real_payload = payload # payload parser if not response_with_body: # don't parse payload if it's not expected to be received self._type = ParseState.PARSE_NONE real_payload.feed_eof() self.done = True elif chunked: self._type = ParseState.PARSE_CHUNKED elif length is not None: self._type = ParseState.PARSE_LENGTH self._length = length if self._length == 0: real_payload.feed_eof() self.done = True self.payload = real_payload def feed_eof(self) -> None: if self._type == ParseState.PARSE_UNTIL_EOF: self.payload.feed_eof() elif self._type == ParseState.PARSE_LENGTH: raise ContentLengthError( "Not enough data to satisfy content length header." ) elif self._type == ParseState.PARSE_CHUNKED: raise TransferEncodingError( "Not enough data to satisfy transfer length header." 
def feed_data(
    self, chunk: bytes, SEP: _SEP = b"\r\n", CHUNK_EXT: bytes = b";"
) -> tuple[bool, bytes]:
    """Feed received body bytes to the payload according to the parse mode.

    :param chunk: newly received bytes.
    :param SEP: line separator used for chunk-size and trailer lines.
    :param CHUNK_EXT: delimiter that starts a chunk-extension on the
        chunk-size line.
    :return: ``(True, rest)`` once the body is complete, where ``rest`` is
        the data following the body; ``(False, b"")`` while more data is
        needed.
    :raises TransferEncodingError: on a malformed chunk-size line or a
        missing separator after chunk data (also set on the payload).
    :raises LineTooLong: when a buffered partial line or a trailer line
        exceeds the configured limit.
    :raises BadHttpMessage: when more than ``max_trailers`` trailer lines
        are received.
    """
    # Read specified amount of bytes
    if self._type == ParseState.PARSE_LENGTH:
        required = self._length
        self._length = max(required - len(chunk), 0)
        self.payload.feed_data(chunk[:required])
        if self._length == 0:
            self.payload.feed_eof()
            # Anything past the declared length is handed back to the caller.
            return True, chunk[required:]

    # Chunked transfer encoding parser
    elif self._type == ParseState.PARSE_CHUNKED:
        if self._chunk_tail:
            # We should never have a tail if we're inside the payload body.
            assert self._chunk != ChunkState.PARSE_CHUNKED_CHUNK
            # We should check the length is sane.
            max_line_length = self._max_line_size
            if self._chunk == ChunkState.PARSE_TRAILERS:
                max_line_length = self._max_field_size
            if len(self._chunk_tail) > max_line_length:
                raise LineTooLong(self._chunk_tail[:100] + b"...", max_line_length)

            # Prepend the incomplete line kept from the previous call.
            chunk = self._chunk_tail + chunk
            self._chunk_tail = b""

        while chunk:
            # read next chunk size
            if self._chunk == ChunkState.PARSE_CHUNKED_SIZE:
                pos = chunk.find(SEP)
                if pos >= 0:
                    i = chunk.find(CHUNK_EXT, 0, pos)
                    if i >= 0:
                        size_b = chunk[:i]  # strip chunk-extensions
                        # Verify no LF in the chunk-extension
                        if b"\n" in (ext := chunk[i:pos]):
                            exc = TransferEncodingError(
                                f"Unexpected LF in chunk-extension: {ext!r}"
                            )
                            set_exception(self.payload, exc)
                            raise exc
                    else:
                        size_b = chunk[:pos]

                    if self._lax:  # Allow whitespace in lax mode.
                        size_b = size_b.strip()

                    if not re.fullmatch(HEXDIGITS, size_b):
                        exc = TransferEncodingError(
                            chunk[:pos].decode("ascii", "surrogateescape")
                        )
                        set_exception(self.payload, exc)
                        raise exc
                    size = int(bytes(size_b), 16)

                    chunk = chunk[pos + len(SEP) :]
                    if size == 0:  # eof marker
                        self._chunk = ChunkState.PARSE_TRAILERS
                        if self._lax and chunk.startswith(b"\r"):
                            chunk = chunk[1:]
                    else:
                        self._chunk = ChunkState.PARSE_CHUNKED_CHUNK
                        self._chunk_size = size
                        self.payload.begin_http_chunk_receiving()
                else:
                    # Size line is incomplete: keep it for the next call.
                    self._chunk_tail = chunk
                    return False, b""

            # read chunk and feed buffer
            if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK:
                required = self._chunk_size
                self._chunk_size = max(required - len(chunk), 0)
                self.payload.feed_data(chunk[:required])

                if self._chunk_size:
                    # The current chunk is not complete yet.
                    return False, b""
                chunk = chunk[required:]
                self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF
                self.payload.end_http_chunk_receiving()

            # toss the CRLF at the end of the chunk
            if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF:
                if self._lax and chunk.startswith(b"\r"):
                    chunk = chunk[1:]
                if chunk[: len(SEP)] == SEP:
                    chunk = chunk[len(SEP) :]
                    self._chunk = ChunkState.PARSE_CHUNKED_SIZE
                elif len(chunk) >= len(SEP) or chunk != SEP[: len(chunk)]:
                    # Enough bytes to compare and they are not the separator,
                    # or the bytes seen so far are not a prefix of it.
                    exc = TransferEncodingError(
                        "Chunk size mismatch: expected CRLF after chunk data"
                    )
                    set_exception(self.payload, exc)
                    raise exc
                else:
                    # Only a prefix of the separator has arrived (possibly
                    # nothing); wait for the rest.
                    self._chunk_tail = chunk
                    return False, b""

            if self._chunk == ChunkState.PARSE_TRAILERS:
                pos = chunk.find(SEP)
                if pos < 0:  # No line found
                    self._chunk_tail = chunk
                    return False, b""

                line = chunk[:pos]
                chunk = chunk[pos + len(SEP) :]
                if SEP == b"\n":  # For lax response parsing
                    line = line.rstrip(b"\r")
                if len(line) > self._max_field_size:
                    raise LineTooLong(line[:100] + b"...", self._max_field_size)
                self._trailer_lines.append(line)

                if len(self._trailer_lines) > self._max_trailers:
                    raise BadHttpMessage("Too many trailers received")

                # \r\n\r\n found, end of stream
                if self._trailer_lines[-1] == b"":
                    # Headers and trailers are defined the same way,
                    # so we reuse the HeadersParser here.
                    try:
                        # The parsed result is not used below; the call
                        # still validates the collected trailer lines.
                        trailers, raw_trailers = self._headers_parser.parse_headers(
                            self._trailer_lines
                        )
                    finally:
                        self._trailer_lines.clear()
                    self.payload.feed_eof()
                    return True, chunk

    # Read all bytes until eof
    elif self._type == ParseState.PARSE_UNTIL_EOF:
        self.payload.feed_data(chunk)

    return False, b""
def feed_eof(self) -> None:
    """Flush the decompressor and signal end-of-stream to the output.

    Flushed data is forwarded when there is any, or when compressed input
    has been received at all. For ``deflate`` an unfinished stream raises
    ContentEncodingError before EOF is signalled.
    """
    tail = self.decompressor.flush()
    received_input = self.size > 0

    if tail or received_input:
        self.out.feed_data(tail)
        # Only the "deflate" decompressor is asked for ``eof``: the
        # decompressor is not brotli unless encoding is "br".
        truncated = (
            self.encoding == "deflate"
            and not self.decompressor.eof  # type: ignore[union-attr]
        )
        if truncated:
            raise ContentEncodingError("deflate")

    self.out.feed_eof()
WSMessage, WSMessageBinary, WSMessageClose, WSMessageClosed, WSMessageClosing, WSMessageContinuation, WSMessageDecodeText, WSMessageError, WSMessageNoDecodeText, WSMessagePing, WSMessagePong, WSMessageText, WSMessageTextBytes, WSMsgType, ) from ._websocket.reader import WebSocketReader from ._websocket.writer import WebSocketWriter # Messages that the WebSocketResponse.receive needs to handle internally _INTERNAL_RECEIVE_TYPES = frozenset( (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.PING, WSMsgType.PONG) ) __all__ = ( "WS_CLOSED_MESSAGE", "WS_CLOSING_MESSAGE", "WS_KEY", "WebSocketReader", "WebSocketWriter", "WSMessage", "WSMessageDecodeText", "WSMessageNoDecodeText", "WebSocketError", "WSMsgType", "WSCloseCode", "ws_ext_gen", "ws_ext_parse", "WSMessageError", "WSHandshakeError", "WSMessageClose", "WSMessageClosed", "WSMessageClosing", "WSMessagePong", "WSMessageBinary", "WSMessageText", "WSMessageTextBytes", "WSMessagePing", "WSMessageContinuation", ) ================================================ FILE: aiohttp/http_writer.py ================================================ """Http related parsers and protocol.""" import asyncio import sys from typing import ( # noqa TYPE_CHECKING, Any, Awaitable, Callable, Iterable, List, NamedTuple, Optional, Union, ) from multidict import CIMultiDict from .abc import AbstractStreamWriter from .base_protocol import BaseProtocol from .client_exceptions import ClientConnectionResetError from .compression_utils import ZLibCompressor from .helpers import NO_EXTENSIONS __all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11") MIN_PAYLOAD_FOR_WRITELINES = 2048 IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2) IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9) SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9 # writelines is not safe for use # on Python 3.12+ until 3.12.9 # on Python 3.13+ until 3.13.2 # and on older versions it not any faster than write # CVE-2024-12254: 
def _writelines(
    self,
    chunks: Iterable[
        Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
    ],
) -> None:
    """Write several buffers to the transport as one logical write.

    The combined size is added to ``buffer_size`` and ``output_size``
    before the transport is checked. The data is joined and sent with
    ``transport.write()`` when ``SKIP_WRITELINES`` is set or the payload is
    smaller than ``MIN_PAYLOAD_FOR_WRITELINES``; otherwise
    ``transport.writelines()`` is used.

    :param chunks: buffers to send, in order. One-shot iterators are
        accepted.
    :raises ClientConnectionResetError: if there is no transport or it is
        closing.
    """
    # ``chunks`` is walked twice (once to measure, once to send). A one-shot
    # iterator would be exhausted by the size pass and nothing would be
    # written, so anything that is not already a tuple or list is copied.
    if not isinstance(chunks, (tuple, list)):
        chunks = tuple(chunks)

    size = sum(len(chunk) for chunk in chunks)
    self.buffer_size += size
    self.output_size += size

    transport = self._protocol.transport
    if transport is None or transport.is_closing():
        raise ClientConnectionResetError("Cannot write to closing transport")

    if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
        transport.write(b"".join(chunks))
    else:
        transport.writelines(chunks)
def set_eof(self) -> None:
    """Mark the message as complete, flushing whatever is still pending.

    Buffered headers that were never sent are written now, together with
    the terminating zero-length chunk when chunking is enabled. If the
    headers were already sent, only the terminating chunk is written.
    Calling this again after EOF is a no-op.
    """
    if self._eof:
        return

    pending_headers = self._headers_buf
    if pending_headers and not self._headers_written:
        # No body was ever written: the headers are still buffered.
        self._headers_written = True
        self._headers_buf = None
        if self.chunked:
            # Headers and the chunked EOF marker go out in a single write.
            self._writelines((pending_headers, b"0\r\n\r\n"))
        else:
            self._write(pending_headers)
    elif self.chunked and self._headers_written:
        # Headers already sent, just send the final chunk marker.
        self._write(b"0\r\n\r\n")

    self._eof = True
(f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n0\r\n\r\n") ) else: self._write(b"0\r\n\r\n") await self.drain() self._eof = True return if chunk: self._write(chunk) await self.drain() self._eof = True async def drain(self) -> None: """Flush the write buffer. The intended use is to write await w.write(data) await w.drain() """ protocol = self._protocol if protocol.transport is not None and protocol._paused: await protocol._drain_helper() def _safe_header(string: str) -> str: if "\r" in string or "\n" in string or "\x00" in string: raise ValueError( "Newline, carriage return, or null byte detected in headers. " "Potential header injection attack." ) return string def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes: _safe_header(status_line) headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items()) line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n" return line.encode("utf-8") _serialize_headers = _py_serialize_headers try: import aiohttp._http_writer as _http_writer # type: ignore[import-not-found] _c_serialize_headers = _http_writer._serialize_headers if not NO_EXTENSIONS: _serialize_headers = _c_serialize_headers except ImportError: pass ================================================ FILE: aiohttp/log.py ================================================ import logging access_logger = logging.getLogger("aiohttp.access") client_logger = logging.getLogger("aiohttp.client") internal_logger = logging.getLogger("aiohttp.internal") server_logger = logging.getLogger("aiohttp.server") web_logger = logging.getLogger("aiohttp.web") ws_logger = logging.getLogger("aiohttp.websocket") ================================================ FILE: aiohttp/multipart.py ================================================ import base64 import binascii import json import re import sys import uuid import warnings from collections import deque from collections.abc import AsyncIterator, Iterator, Mapping, Sequence from 
class BadContentDispositionHeader(RuntimeWarning):
    """Warning category for a Content-Disposition header that cannot be parsed."""
{} disptype, *parts = header.split(";") if not is_token(disptype): warnings.warn(BadContentDispositionHeader(header)) return None, {} params: dict[str, str] = {} while parts: item = parts.pop(0) if not item: # To handle trailing semicolons warnings.warn(BadContentDispositionHeader(header)) continue if "=" not in item: warnings.warn(BadContentDispositionHeader(header)) return None, {} key, value = item.split("=", 1) key = key.lower().strip() value = value.lstrip() if key in params: warnings.warn(BadContentDispositionHeader(header)) return None, {} if not is_token(key): warnings.warn(BadContentDispositionParam(item)) continue elif is_continuous_param(key): if is_quoted(value): value = unescape(value[1:-1]) elif not is_token(value): warnings.warn(BadContentDispositionParam(item)) continue elif is_extended_param(key): if is_rfc5987(value): encoding, _, value = value.split("'", 2) encoding = encoding or "utf-8" else: warnings.warn(BadContentDispositionParam(item)) continue try: value = unquote(value, encoding, "strict") except UnicodeDecodeError: # pragma: nocover warnings.warn(BadContentDispositionParam(item)) continue else: failed = True if is_quoted(value): failed = False value = unescape(value[1:-1].lstrip("\\/")) elif is_token(value): failed = False elif parts: # maybe just ; in filename, in any case this is just # one case fix, for proper fix we need to redesign parser _value = f"{value};{parts[0]}" if is_quoted(_value): parts.pop(0) value = unescape(_value[1:-1].lstrip("\\/")) failed = False if failed: warnings.warn(BadContentDispositionHeader(header)) return None, {} params[key] = value return disptype.lower(), params def content_disposition_filename( params: Mapping[str, str], name: str = "filename" ) -> str | None: name_suf = "%s*" % name if not params: return None elif name_suf in params: return params[name_suf] elif name in params: return params[name] else: parts = [] fnparams = sorted( (key, value) for key, value in params.items() if 
key.startswith(name_suf) ) for num, (key, value) in enumerate(fnparams): _, tail = key.split("*", 1) if tail.endswith("*"): tail = tail[:-1] if tail == str(num): parts.append(value) else: break if not parts: return None value = "".join(parts) if "'" in value: encoding, _, value = value.split("'", 2) encoding = encoding or "utf-8" return unquote(value, encoding, "strict") return value class MultipartResponseWrapper: """Wrapper around the MultipartReader. It takes care about underlying connection and close it when it needs in. """ def __init__( self, resp: "ClientResponse", stream: "MultipartReader", ) -> None: self.resp = resp self.stream = stream def __aiter__(self) -> "MultipartResponseWrapper": return self async def __anext__( self, ) -> Union["MultipartReader", "BodyPartReader"]: part = await self.next() if part is None: raise StopAsyncIteration return part def at_eof(self) -> bool: """Returns True when all response data had been read.""" return self.resp.content.at_eof() async def next( self, ) -> Union["MultipartReader", "BodyPartReader"] | None: """Emits next multipart reader object.""" item = await self.stream.next() if self.stream.at_eof(): await self.release() return item async def release(self) -> None: """Release the connection gracefully. All remaining content is read to the void. 
async def read(self, *, decode: bool = False) -> bytes:
    """Read the whole body part.

    decode: when true, the collected data is passed through
            ``decode_iter()`` and the decoded result is returned;
            otherwise the data is returned untouched.
    """
    if self._at_eof:
        return b""

    raw = bytearray()
    while not self._at_eof:
        raw.extend(await self.read_chunk(self.chunk_size))

    # https://github.com/python/mypy/issues/17537
    if not decode:  # type: ignore[unreachable]
        return raw

    decoded = bytearray()
    async for piece in self.decode_iter(raw):
        decoded.extend(piece)
    return decoded
size: chunk size """ if self._at_eof: return b"" if self._length: chunk = await self._read_chunk_from_length(size) else: chunk = await self._read_chunk_from_stream(size) # For the case of base64 data, we must read a fragment of size with a # remainder of 0 by dividing by 4 for string without symbols \n or \r encoding = self.headers.get(CONTENT_TRANSFER_ENCODING) if encoding and encoding.lower() == "base64": stripped_chunk = b"".join(chunk.split()) remainder = len(stripped_chunk) % 4 while remainder != 0 and not self.at_eof(): over_chunk_size = 4 - remainder over_chunk = b"" if self._prev_chunk: over_chunk = self._prev_chunk[:over_chunk_size] self._prev_chunk = self._prev_chunk[len(over_chunk) :] if len(over_chunk) != over_chunk_size: over_chunk += await self._content.read(4 - len(over_chunk)) if not over_chunk: self._at_eof = True stripped_chunk += b"".join(over_chunk.split()) chunk += over_chunk remainder = len(stripped_chunk) % 4 self._read_bytes += len(chunk) if self._read_bytes == self._length: self._at_eof = True if self._at_eof and await self._content.readline() != b"\r\n": raise ValueError("Reader did not read all the data or it is malformed") return chunk async def _read_chunk_from_length(self, size: int) -> bytes: # Reads body part content chunk of the specified size. # The body part must has Content-Length header with proper value. assert self._length is not None, "Content-Length required for chunked read" chunk_size = min(size, self._length - self._read_bytes) chunk = await self._content.read(chunk_size) if self._content.at_eof(): self._at_eof = True return chunk async def _read_chunk_from_stream(self, size: int) -> bytes: # Reads content chunk of body part with unknown length. # The Content-Length header for body part is not necessary. 
assert ( size >= self._boundary_len ), "Chunk size must be greater or equal than boundary length + 2" first_chunk = self._prev_chunk is None if first_chunk: # We need to re-add the CRLF that got removed from headers parsing. self._prev_chunk = b"\r\n" + await self._content.read(size) chunk = b"" # content.read() may return less than size, so we need to loop to ensure # we have enough data to detect the boundary. while len(chunk) < self._boundary_len: chunk += await self._content.read(size) self._content_eof += int(self._content.at_eof()) if self._content_eof > 2: raise ValueError("Reading after EOF") if self._content_eof: break if len(chunk) > size: self._content.unread_data(chunk[size:]) chunk = chunk[:size] assert self._prev_chunk is not None window = self._prev_chunk + chunk sub = b"\r\n" + self._boundary if first_chunk: idx = window.find(sub) else: idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub))) if idx >= 0: # pushing boundary back to content with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) self._content.unread_data(window[idx:]) self._prev_chunk = self._prev_chunk[:idx] chunk = window[len(self._prev_chunk) : idx] if not chunk: self._at_eof = True result = self._prev_chunk[2 if first_chunk else 0 :] # Strip initial CRLF self._prev_chunk = chunk return result async def readline(self) -> bytes: """Reads body part by line by line.""" if self._at_eof: return b"" if self._unread: line = self._unread.popleft() else: line = await self._content.readline() if line.startswith(self._boundary): # the very last boundary may not come with \r\n, # so set single rules for everyone sline = line.rstrip(b"\r\n") boundary = self._boundary last_boundary = self._boundary + b"--" # ensure that we read exactly the boundary, not something alike if sline == boundary or sline == last_boundary: self._at_eof = True self._unread.append(line) return b"" else: next_line = await self._content.readline() if 
next_line.startswith(self._boundary): line = line[:-2] # strip CRLF but only once self._unread.append(next_line) return line async def release(self) -> None: """Like read(), but reads all the data to the void.""" if self._at_eof: return while not self._at_eof: await self.read_chunk(self.chunk_size) async def text(self, *, encoding: str | None = None) -> str: """Like read(), but assumes that body part contains text data.""" data = await self.read(decode=True) # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send encoding = encoding or self.get_charset(default="utf-8") return data.decode(encoding) async def json(self, *, encoding: str | None = None) -> dict[str, Any] | None: """Like read(), but assumes that body parts contains JSON data.""" data = await self.read(decode=True) if not data: return None encoding = encoding or self.get_charset(default="utf-8") return cast(dict[str, Any], json.loads(data.decode(encoding))) async def form(self, *, encoding: str | None = None) -> list[tuple[str, str]]: """Like read(), but assumes that body parts contain form urlencoded data.""" data = await self.read(decode=True) if not data: return [] if encoding is not None: real_encoding = encoding else: real_encoding = self.get_charset(default="utf-8") try: decoded_data = data.rstrip().decode(real_encoding) except UnicodeDecodeError: raise ValueError("data cannot be decoded with %s encoding" % real_encoding) return parse_qsl( decoded_data, keep_blank_values=True, encoding=real_encoding, ) def at_eof(self) -> bool: """Returns True if the boundary was reached or False otherwise.""" return self._at_eof def _apply_content_transfer_decoding(self, data: bytes) -> bytes: """Apply Content-Transfer-Encoding decoding if header is present.""" if CONTENT_TRANSFER_ENCODING in self.headers: return self._decode_content_transfer(data) return data def _needs_content_decoding(self) -> bool: 
"""Check if Content-Encoding decoding should be applied.""" # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 return not self._is_form_data and CONTENT_ENCODING in self.headers def decode(self, data: bytes) -> bytes: """Decodes data synchronously. Decodes data according the specified Content-Encoding or Content-Transfer-Encoding headers value. Note: For large payloads, consider using decode_iter() instead to avoid blocking the event loop during decompression. """ data = self._apply_content_transfer_decoding(data) if self._needs_content_decoding(): return self._decode_content(data) return data async def decode_iter(self, data: bytes) -> AsyncIterator[bytes]: """Async generator that yields decoded data chunks. Decodes data according the specified Content-Encoding or Content-Transfer-Encoding headers value. This method offloads decompression to an executor for large payloads to avoid blocking the event loop. """ data = self._apply_content_transfer_decoding(data) if self._needs_content_decoding(): async for d in self._decode_content_async(data): yield d else: yield data def _decode_content(self, data: bytes) -> bytes: encoding = self.headers.get(CONTENT_ENCODING, "").lower() if encoding == "identity": return data if encoding in {"deflate", "gzip"}: return ZLibDecompressor( encoding=encoding, suppress_deflate_header=True, ).decompress_sync(data, max_length=self._max_decompress_size) raise RuntimeError(f"unknown content encoding: {encoding}") async def _decode_content_async(self, data: bytes) -> AsyncIterator[bytes]: encoding = self.headers.get(CONTENT_ENCODING, "").lower() if encoding == "identity": yield data elif encoding in {"deflate", "gzip"}: d = ZLibDecompressor( encoding=encoding, suppress_deflate_header=True, ) yield await d.decompress(data, max_length=self._max_decompress_size) else: raise RuntimeError(f"unknown content encoding: {encoding}") def _decode_content_transfer(self, data: bytes) -> bytes: encoding = 
@payload_type(BodyPartReader, order=Order.try_first)
class BodyPartReaderPayload(Payload):
    """Payload that forwards the content of a ``BodyPartReader``.

    The wrapped reader is consumed chunk by chunk in write(); it cannot be
    rendered as text or buffered as bytes.
    """

    _value: BodyPartReader
    # _autoclose = False (inherited) - Streaming reader that may have resources

    def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None:
        super().__init__(value, *args, **kwargs)

        # Carry over whichever of name / filename the source part declares.
        disposition_fields = (("name", value.name), ("filename", value.filename))
        params: dict[str, str] = {
            key: field for key, field in disposition_fields if field is not None
        }
        if params:
            self.set_content_disposition("attachment", True, **params)

    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
        """Always raises TypeError: a streamed part has no text form."""
        raise TypeError("Unable to decode.")

    async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
        """Always raises TypeError; consume the part with write() instead.

        Body parts are meant to be streamed exactly once (they may be very
        large), so buffering them in memory for reuse is refused on purpose.
        """
        raise TypeError("Unable to read body part as bytes. Use write() to consume.")

    async def write(self, writer: AbstractStreamWriter) -> None:
        """Stream the part to ``writer`` in decoded 256 KiB chunks."""
        reader = self._value
        while True:
            raw = await reader.read_chunk(size=2**18)
            if not raw:
                break
            async for decoded in reader.decode_iter(raw):
                await writer.write(decoded)
class MultipartReader:
    """Multipart body reader.

    Iterates over the parts of a ``multipart/*`` body. Each part is returned
    either as a ``BodyPartReader`` or, for nested ``multipart/*`` parts, as
    another multipart reader sharing the same underlying stream.
    """

    #: Response wrapper, used when multipart readers constructs from response.
    response_wrapper_cls = MultipartResponseWrapper
    #: Multipart reader class, used to handle multipart/* body parts.
    #: None points to type(self)
    multipart_reader_cls: type["MultipartReader"] | None = None
    #: Body part reader class for non multipart/* content types.
    part_reader_cls = BodyPartReader

    def __init__(
        self,
        headers: Mapping[str, str],
        content: StreamReader,
        *,
        max_field_size: int = 8190,
        max_headers: int = 128,
    ) -> None:
        """Initialize the reader.

        headers: headers of the multipart message; Content-Type must be
                 ``multipart/*`` and carry a ``boundary`` parameter.
        content: stream the body is read from.
        max_field_size: maximum length of a single part header line.
        max_headers: maximum number of header lines per part.

        Raises ValueError when the boundary parameter is missing or longer
        than 70 characters.
        """
        self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
        assert self._mimetype.type == "multipart", "multipart/* content type expected"
        if "boundary" not in self._mimetype.parameters:
            raise ValueError(
                "boundary missed for Content-Type: %s" % headers[CONTENT_TYPE]
            )

        self.headers = headers
        # Delimiter lines start with "--" followed by the boundary parameter.
        self._boundary = ("--" + self._get_boundary()).encode()
        self._content = content
        self._default_charset: str | None = None
        self._last_part: MultipartReader | BodyPartReader | None = None
        self._max_field_size = max_field_size
        self._max_headers = max_headers
        self._at_eof = False
        self._at_bof = True
        # Lines pushed back for re-reading; consumed from the end (stack).
        self._unread: list[bytes] = []

    def __aiter__(self) -> Self:
        return self

    async def __anext__(
        self,
    ) -> Union["MultipartReader", BodyPartReader] | None:
        # Async-iteration protocol: stop once next() reports no more parts.
        part = await self.next()
        if part is None:
            raise StopAsyncIteration
        return part

    @classmethod
    def from_response(
        cls,
        response: "ClientResponse",
    ) -> MultipartResponseWrapper:
        """Constructs reader instance from HTTP response.

        :param response: :class:`~aiohttp.client.ClientResponse` instance
        """
        obj = cls.response_wrapper_cls(
            response, cls(response.headers, response.content)
        )
        return obj

    def at_eof(self) -> bool:
        """Returns True if the final boundary was reached, false otherwise."""
        return self._at_eof

    async def next(
        self,
    ) -> Union["MultipartReader", BodyPartReader] | None:
        """Emits the next multipart body part.

        Returns None once the closing boundary has been read.
        """
        # So, if we're at BOF, we need to skip till the boundary.
        if self._at_eof:
            return None
        # The previous part must be drained before the next boundary is read.
        await self._maybe_release_last_part()
        if self._at_bof:
            await self._read_until_first_boundary()
            self._at_bof = False
        else:
            await self._read_boundary()
        if self._at_eof:  # we just read the last boundary, nothing to do there
            # https://github.com/python/mypy/issues/17537
            return None  # type: ignore[unreachable]

        part = await self.fetch_next_part()
        # https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
        if (
            self._last_part is None
            and self._mimetype.subtype == "form-data"
            and isinstance(part, BodyPartReader)
        ):
            _, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
            if params.get("name") == "_charset_":
                # Longest encoding in https://encoding.spec.whatwg.org/encodings.json
                # is 19 characters, so 32 should be more than enough for any valid encoding.
                charset = await part.read_chunk(32)
                if len(charset) > 31:
                    raise RuntimeError("Invalid default charset")
                self._default_charset = charset.strip().decode()
                # The _charset_ field itself is not exposed to the caller.
                part = await self.fetch_next_part()
        self._last_part = part
        return self._last_part

    async def release(self) -> None:
        """Reads all the body parts to the void till the final boundary."""
        while not self._at_eof:
            item = await self.next()
            if item is None:
                break
            await item.release()

    async def fetch_next_part(
        self,
    ) -> Union["MultipartReader", BodyPartReader]:
        """Returns the next body part reader."""
        headers = await self._read_headers()
        return self._get_part_reader(headers)

    def _get_part_reader(
        self,
        headers: "CIMultiDictProxy[str]",
    ) -> Union["MultipartReader", BodyPartReader]:
        """Dispatches the response by the `Content-Type` header.

        Returns a suitable reader instance.

        :param dict headers: Response headers
        """
        ctype = headers.get(CONTENT_TYPE, "")
        mimetype = parse_mimetype(ctype)

        if mimetype.type == "multipart":
            reader_cls = self.multipart_reader_cls
            if reader_cls is None:
                reader_cls = type(self)
            # Nested readers inherit the header limits of their parent in
            # both cases, so the limits cannot be bypassed by nesting.
            return reader_cls(
                headers,
                self._content,
                max_field_size=self._max_field_size,
                max_headers=self._max_headers,
            )
        return self.part_reader_cls(
            self._boundary,
            headers,
            self._content,
            subtype=self._mimetype.subtype,
            default_charset=self._default_charset,
        )

    def _get_boundary(self) -> str:
        # Returns the boundary parameter; rejects values above 70 characters.
        boundary = self._mimetype.parameters["boundary"]
        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        return boundary

    async def _readline(self) -> bytes:
        # Pushed-back lines take priority over the stream.
        if self._unread:
            return self._unread.pop()
        return await self._content.readline()

    async def _read_until_first_boundary(self) -> None:
        # Skips everything up to the first delimiter line; raises ValueError
        # when the stream ends before one is found.
        while True:
            chunk = await self._readline()
            if chunk == b"":
                raise ValueError(f"Could not find starting boundary {self._boundary!r}")
            chunk = chunk.rstrip()
            if chunk == self._boundary:
                return
            if chunk == self._boundary + b"--":
                self._at_eof = True
                return

    async def _read_boundary(self) -> None:
        # Consumes the delimiter line that must follow a fully read part.
        chunk = (await self._readline()).rstrip()
        if chunk == self._boundary:
            pass
        elif chunk == self._boundary + b"--":
            self._at_eof = True
            epilogue = await self._readline()
            next_line = await self._readline()

            # the epilogue is expected and then either the end of input or the
            # parent multipart boundary, if the parent boundary is found then
            # it should be marked as unread and handed to the parent for
            # processing
            if next_line[:2] == b"--":
                self._unread.append(next_line)
            # otherwise the request is likely missing an epilogue and both
            # lines should be passed to the parent for processing
            # (this handles the old behavior gracefully)
            else:
                # _readline() pops from the end, so this order hands back
                # ``epilogue`` first and ``next_line`` second.
                self._unread.extend([next_line, epilogue])
        else:
            raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}")

    async def _read_headers(self) -> "CIMultiDictProxy[str]":
        # Reads header lines up to and including the terminating empty line.
        lines = []
        while True:
            chunk = await self._content.readline(max_line_length=self._max_field_size)
            chunk = chunk.rstrip(b"\r\n")
            lines.append(chunk)
            if not chunk:
                break
            if len(lines) > self._max_headers:
                raise BadHttpMessage("Too many headers received")
        parser = HeadersParser(max_field_size=self._max_field_size)
        headers, _ = parser.parse_headers(lines)
        return headers

    async def _maybe_release_last_part(self) -> None:
        """Ensures that the last read body part is read completely."""
        if self._last_part is not None:
            if not self._last_part.at_eof():
                await self._last_part.release()
            # Take over lines the part read ahead but did not consume.
            self._unread.extend(self._last_part._unread)
            self._last_part = None
_Part = tuple[Payload, str, str]


class MultipartWriter(Payload):
    """Multipart body writer.

    Collects payloads as body parts and serializes them, separated by the
    boundary, as a ``multipart/<subtype>`` body.
    """

    _value: None
    # _consumed = False (inherited) - Can be encoded multiple times
    _autoclose = True  # No file handles, just collects parts in memory

    def __init__(self, subtype: str = "mixed", boundary: str | None = None) -> None:
        """Initialize the writer.

        subtype: multipart subtype; "form-data" enables the form-data rules
                 applied in append_payload() and write().
        boundary: part delimiter; a random hex string is generated when None.

        Raises ValueError for a non-ASCII boundary, a boundary longer than
        70 characters, or one containing characters that cannot be quoted.
        """
        boundary = boundary if boundary is not None else uuid.uuid4().hex
        # The underlying Payload API demands a str (utf-8), not bytes,
        # so we need to ensure we don't lose anything during conversion.
        # As a result, require the boundary to be ASCII only.
        # In both situations.

        try:
            self._boundary = boundary.encode("ascii")
        except UnicodeEncodeError:
            raise ValueError("boundary should contain ASCII only chars") from None

        if len(boundary) > 70:
            raise ValueError("boundary %r is too long (70 chars max)" % boundary)

        ctype = f"multipart/{subtype}; boundary={self._boundary_value}"

        super().__init__(None, content_type=ctype)

        # Each entry: (payload, content encoding, content transfer encoding).
        self._parts: list[_Part] = []
        self._is_form_data = subtype == "form-data"

    def __enter__(self) -> "MultipartWriter":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        pass

    def __iter__(self) -> Iterator[_Part]:
        return iter(self._parts)

    def __len__(self) -> int:
        return len(self._parts)

    def __bool__(self) -> bool:
        # A writer without parts would otherwise be falsy through __len__.
        return True

    _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z")
    _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]")

    @property
    def _boundary_value(self) -> str:
        """Wrap boundary parameter value in quotes, if necessary.

        Reads self.boundary and returns a unicode string.

        Raises ValueError when the boundary contains control characters
        that are not allowed inside a quoted-string.
        """
        # Refer to RFCs 7231, 7230, 5234.
        #
        # parameter      = token "=" ( token / quoted-string )
        # token          = 1*tchar
        # quoted-string  = DQUOTE *( qdtext / quoted-pair ) DQUOTE
        # qdtext         = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text
        # obs-text       = %x80-FF
        # quoted-pair    = "\" ( HTAB / SP / VCHAR / obs-text )
        # tchar          = "!" / "#" / "$" / "%" / "&" / "'" / "*"
        #                  / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
        #                  / DIGIT / ALPHA
        #                  ; any VCHAR, except delimiters
        # VCHAR           = %x21-7E
        value = self._boundary
        if re.match(self._valid_tchar_regex, value):
            return value.decode("ascii")  # cannot fail

        if re.search(self._invalid_qdtext_char_regex, value):
            raise ValueError("boundary value contains invalid characters")

        # escape %x5C and %x22
        quoted_value_content = value.replace(b"\\", b"\\\\")
        quoted_value_content = quoted_value_content.replace(b'"', b'\\"')

        return '"' + quoted_value_content.decode("ascii") + '"'

    @property
    def boundary(self) -> str:
        """The boundary as text, without quoting."""
        return self._boundary.decode("ascii")

    def append(self, obj: Any, headers: Mapping[str, str] | None = None) -> Payload:
        """Adds a new body part built from ``obj``.

        A Payload instance is used as is (``headers`` are merged into its
        headers); any other object is converted with get_payload().
        Raises TypeError when no payload can be created for ``obj``.
        """
        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Payload):
            obj.headers.update(headers)
            return self.append_payload(obj)
        try:
            payload = get_payload(obj, headers=headers)
        except LookupError:
            raise TypeError("Cannot create payload from %r" % obj)
        else:
            return self.append_payload(payload)

    def append_payload(self, payload: Payload) -> Payload:
        """Adds a new body part to multipart writer.

        Raises RuntimeError for an unsupported Content-Encoding or
        Content-Transfer-Encoding header on the payload.
        """
        encoding: str | None = None
        te_encoding: str | None = None
        if self._is_form_data:
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
            # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
            assert (
                not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
                & payload.headers.keys()
            )
            # Set default Content-Disposition in case user doesn't create one
            if CONTENT_DISPOSITION not in payload.headers:
                name = f"section-{len(self._parts)}"
                payload.set_content_disposition("form-data", name=name)
        else:
            # compression
            encoding = payload.headers.get(CONTENT_ENCODING, "").lower()
            if encoding and encoding not in ("deflate", "gzip", "identity"):
                raise RuntimeError(f"unknown content encoding: {encoding}")
            if encoding == "identity":
                encoding = None

            # te encoding
            te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
            if te_encoding not in ("", "base64", "quoted-printable", "binary"):
                raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
            if te_encoding == "binary":
                te_encoding = None

            # size
            # Content-Length is only known when the bytes go out unmodified.
            size = payload.size
            if size is not None and not (encoding or te_encoding):
                payload.headers[CONTENT_LENGTH] = str(size)

        self._parts.append((payload, encoding, te_encoding))  # type: ignore[arg-type]
        return payload

    def append_json(
        self, obj: Any, headers: Mapping[str, str] | None = None
    ) -> Payload:
        """Helper to append JSON part."""
        if headers is None:
            headers = CIMultiDict()

        return self.append_payload(JsonPayload(obj, headers=headers))

    def append_form(
        self,
        obj: Sequence[tuple[str, str]] | Mapping[str, str],
        headers: Mapping[str, str] | None = None,
    ) -> Payload:
        """Helper to append form urlencoded part."""
        assert isinstance(obj, (Sequence, Mapping))

        if headers is None:
            headers = CIMultiDict()

        if isinstance(obj, Mapping):
            obj = list(obj.items())
        data = urlencode(obj, doseq=True)

        return self.append_payload(
            StringPayload(
                data, headers=headers, content_type="application/x-www-form-urlencoded"
            )
        )

    @property
    def size(self) -> int | None:
        """Size of the payload.

        None when any part is encoded or has an unknown size.
        """
        total = 0
        for part, encoding, te_encoding in self._parts:
            part_size = part.size
            if encoding or te_encoding or part_size is None:
                return None

            total += int(
                2
                + len(self._boundary)
                + 2
                + part_size  # b'--'+self._boundary+b'\r\n'
                + len(part._binary_headers)
                + 2  # b'\r\n'
            )

        total += 2 + len(self._boundary) + 4  # b'--'+self._boundary+b'--\r\n'
        return total

    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
        """Return string representation of the multipart data.

        Each part is rendered as its boundary line, its headers and its
        decoded content. The CRLF after each part and the closing boundary
        are not included.

        WARNING: This method may do blocking I/O if parts contain file payloads.
        It should not be called in the event loop. Use as_bytes().decode() instead.
        """
        # Part bodies honour the requested encoding and error handling, the
        # same way as_bytes() forwards them to each part.
        return "".join(
            "--"
            + self.boundary
            + "\r\n"
            + part._binary_headers.decode(encoding, errors)
            + part.decode(encoding, errors)
            for part, _e, _te in self._parts
        )

    async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
        """Return bytes representation of the multipart data.

        This method is async-safe and calls as_bytes on underlying payloads.
        """
        parts: list[bytes] = []

        # Process each part
        for part, _e, _te in self._parts:
            # Add boundary
            parts.append(b"--" + self._boundary + b"\r\n")

            # Add headers
            parts.append(part._binary_headers)

            # Add payload content using as_bytes for async safety
            part_bytes = await part.as_bytes(encoding, errors)
            parts.append(part_bytes)

            # Add trailing CRLF
            parts.append(b"\r\n")

        # Add closing boundary
        parts.append(b"--" + self._boundary + b"--\r\n")

        return b"".join(parts)

    async def write(
        self, writer: AbstractStreamWriter, close_boundary: bool = True
    ) -> None:
        """Write body.

        close_boundary: when False the closing boundary line is not written.
        """
        for part, encoding, te_encoding in self._parts:
            if self._is_form_data:
                # https://datatracker.ietf.org/doc/html/rfc7578#section-4.2
                assert CONTENT_DISPOSITION in part.headers
                assert "name=" in part.headers[CONTENT_DISPOSITION]

            await writer.write(b"--" + self._boundary + b"\r\n")
            await writer.write(part._binary_headers)

            if encoding or te_encoding:
                # Encoded parts go through a wrapper that compresses and/or
                # transfer-encodes on the way to the real writer.
                w = MultipartPayloadWriter(writer)
                if encoding:
                    w.enable_compression(encoding)
                if te_encoding:
                    w.enable_encoding(te_encoding)
                await part.write(w)  # type: ignore[arg-type]
                await w.write_eof()
            else:
                await part.write(writer)

            await writer.write(b"\r\n")

        if close_boundary:
            await writer.write(b"--" + self._boundary + b"--\r\n")

    async def close(self) -> None:
        """
        Close all part payloads that need explicit closing.

        IMPORTANT: This method must not await anything that might not finish
        immediately, as it may be called during cleanup/cancellation. Schedule
        any long-running operations without awaiting them.
        """
        if self._consumed:
            return
        self._consumed = True

        # Close all parts that need explicit closing
        # We catch and log exceptions to ensure all parts get a chance to close
        # we do not use asyncio.gather() here because we are not allowed
        # to suspend given we may be called during cleanup
        for idx, (part, _, _) in enumerate(self._parts):
            if not part.autoclose and not part.consumed:
                try:
                    await part.close()
                except Exception as exc:
                    internal_logger.error(
                        "Failed to close multipart part %d: %s", idx, exc, exc_info=True
                    )
class MultipartPayloadWriter:
    """Writer wrapper that compresses and/or transfer-encodes a part body.

    Data passed to write() is optionally compressed, then optionally
    base64 / quoted-printable encoded, and forwarded to the wrapped writer.
    """

    def __init__(self, writer: AbstractStreamWriter) -> None:
        self._writer = writer
        self._encoding: str | None = None
        self._compress: ZLibCompressor | None = None
        self._encoding_buffer: bytearray | None = None

    def enable_encoding(self, encoding: str) -> None:
        """Select a transfer encoding; unknown names are ignored."""
        if encoding not in ("base64", "quoted-printable"):
            return
        self._encoding = encoding
        if encoding == "base64":
            # base64 works on 3-byte groups, leftovers are kept here.
            self._encoding_buffer = bytearray()

    def enable_compression(
        self, encoding: str = "deflate", strategy: int | None = None
    ) -> None:
        """Compress everything written from now on with ``encoding``."""
        self._compress = ZLibCompressor(
            encoding=encoding,
            suppress_deflate_header=True,
            strategy=strategy,
        )

    async def write_eof(self) -> None:
        """Flush the compressor and any buffered base64 remainder."""
        compressor = self._compress
        if compressor is not None:
            tail = compressor.flush()
            if tail:
                # Drop the compressor first so the flushed bytes are not
                # compressed a second time by write().
                self._compress = None
                await self.write(tail)

        if self._encoding == "base64" and self._encoding_buffer:
            await self._writer.write(base64.b64encode(self._encoding_buffer))

    async def write(self, chunk: bytes) -> None:
        """Compress/encode ``chunk`` and pass it to the wrapped writer."""
        if self._compress is not None and chunk:
            chunk = await self._compress.compress(chunk)
            if not chunk:
                # The compressor buffered everything; nothing to emit yet.
                return

        if self._encoding == "base64":
            pending = self._encoding_buffer
            assert pending is not None
            pending.extend(chunk)
            if not pending:
                return
            # Encode only whole 3-byte groups and keep the rest for later.
            usable = len(pending) // 3 * 3
            ready = pending[:usable]
            self._encoding_buffer = pending[usable:]
            if ready:
                await self._writer.write(base64.b64encode(ready))
            return

        if self._encoding == "quoted-printable":
            await self._writer.write(binascii.b2a_qp(chunk))
            return

        await self._writer.write(chunk)
class LookupError(Exception):
    """Signals that no registered factory accepts the given data."""


class Order(str, enum.Enum):
    """Priority group a payload factory is registered in."""

    normal = "normal"
    try_first = "try_first"
    try_last = "try_last"


def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
    """Build a payload for ``data`` using the module-wide registry."""
    return PAYLOAD_REGISTRY.get(data, *args, **kwargs)


def register_payload(
    factory: type["Payload"], type: Any, *, order: Order = Order.normal
) -> None:
    """Register ``factory`` for ``type`` in the module-wide registry."""
    PAYLOAD_REGISTRY.register(factory, type, order=order)


class payload_type:
    """Class decorator that registers the decorated payload factory."""

    def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
        self.type = type
        self.order = order

    def __call__(self, factory: type["Payload"]) -> type["Payload"]:
        register_payload(factory, self.type, order=self.order)
        return factory


PayloadType = type["Payload"]
_PayloadRegistryItem = tuple[PayloadType, Any]


class PayloadRegistry:
    """Registry mapping data types to payload factories.

    note: we need zope.interface for more efficient adapter search
    """

    __slots__ = ("_first", "_normal", "_last", "_normal_lookup")

    def __init__(self) -> None:
        self._first: list[_PayloadRegistryItem] = []
        self._normal: list[_PayloadRegistryItem] = []
        self._last: list[_PayloadRegistryItem] = []
        # Exact-type shortcut for factories registered with normal order.
        self._normal_lookup: dict[Any, PayloadType] = {}

    def get(
        self,
        data: Any,
        *args: Any,
        _CHAIN: "type[chain[_PayloadRegistryItem]]" = chain,
        **kwargs: Any,
    ) -> "Payload":
        """Create a payload for ``data``; raise LookupError if none fits."""
        # Factories registered as try_first always win.
        for factory, accepted in self._first:
            if isinstance(data, accepted):
                return factory(data, *args, **kwargs)

        # Exact type match avoids the linear scan below.
        exact_factory = self._normal_lookup.get(type(data))
        if exact_factory:
            return exact_factory(data, *args, **kwargs)

        # Ready-made payloads are handed back untouched.
        if isinstance(data, Payload):
            return data

        # Slow path: isinstance checks over normal, then try_last entries.
        for factory, accepted in _CHAIN(self._normal, self._last):
            if isinstance(data, accepted):
                return factory(data, *args, **kwargs)
        raise LookupError()

    def register(
        self, factory: PayloadType, type: Any, *, order: Order = Order.normal
    ) -> None:
        """Add ``factory`` for ``type`` (a type or an iterable of types)."""
        entry = (factory, type)
        if order is Order.try_first:
            self._first.append(entry)
            return
        if order is Order.try_last:
            self._last.append(entry)
            return
        if order is not Order.normal:
            raise ValueError(f"Unsupported order {order!r}")

        self._normal.append(entry)
        exact_types = type if isinstance(type, Iterable) else (type,)
        for exact in exact_types:
            self._normal_lookup[exact] = factory
class Payload(ABC):
    """Abstract base for everything that can be sent as a body."""

    _default_content_type: str = "application/octet-stream"
    _size: int | None = None
    _consumed: bool = False  # Default: payload has not been consumed yet
    _autoclose: bool = False  # Default: assume resource needs explicit closing

    def __init__(
        self,
        value: Any,
        headers: (
            CIMultiDict[str] | dict[str, str] | Iterable[tuple[str, str]] | None
        ) = None,
        content_type: None | str | _SENTINEL = sentinel,
        filename: str | None = None,
        encoding: str | None = None,
        **kwargs: Any,
    ) -> None:
        self._value = value
        self._filename = filename
        self._encoding = encoding
        self._headers = CIMultiDict[str]()

        # Content-Type: explicit value first, then a guess from the
        # filename, finally the class default.
        if content_type is not None and content_type is not sentinel:
            assert isinstance(content_type, str)
            resolved_type = content_type
        elif filename is None:
            resolved_type = self._default_content_type
        else:
            if sys.version_info >= (3, 13):
                guess = mimetypes.guess_file_type
            else:
                guess = mimetypes.guess_type
            guessed = guess(filename)[0]
            resolved_type = self._default_content_type if guessed is None else guessed
        self._headers[hdrs.CONTENT_TYPE] = resolved_type

        # Caller supplied headers are merged in after Content-Type is set.
        if headers:
            self._headers.update(headers)

    @property
    def size(self) -> int | None:
        """Number of bytes the payload will put on the wire, if known.

        For text payloads this is the encoded size, not the string length.
        """
        return self._size

    @property
    def filename(self) -> str | None:
        """File name associated with the payload, if any."""
        return self._filename

    @property
    def headers(self) -> CIMultiDict[str]:
        """Headers attached to this payload."""
        return self._headers

    @property
    def _binary_headers(self) -> bytes:
        # "Name: value" lines followed by the blank line that ends the block.
        header_block = "".join(
            name + ": " + value + "\r\n" for name, value in self.headers.items()
        )
        return header_block.encode("utf-8") + b"\r\n"

    @property
    def encoding(self) -> str | None:
        """Text encoding of the payload, if any."""
        return self._encoding

    @property
    def content_type(self) -> str:
        """Value of the Content-Type header."""
        return self._headers[hdrs.CONTENT_TYPE]

    @property
    def consumed(self) -> bool:
        """True once the payload was consumed and cannot be sent again."""
        return self._consumed

    @property
    def autoclose(self) -> bool:
        """
        True when the payload holds nothing that needs explicit closing.

        When this is False the owner has to await close() to release the
        underlying resources.
        """
        return self._autoclose

    def set_content_disposition(
        self,
        disptype: str,
        quote_fields: bool = True,
        _charset: str = "utf-8",
        **params: str,
    ) -> None:
        """Sets ``Content-Disposition`` header."""
        header_value = content_disposition_header(
            disptype, quote_fields=quote_fields, _charset=_charset, params=params
        )
        self._headers[hdrs.CONTENT_DISPOSITION] = header_value

    @abstractmethod
    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
        """
        Return the value as text.

        The name mirrors ``bytes.decode`` so payloads and bytes objects can
        be used interchangeably.
        """

    @abstractmethod
    async def write(self, writer: AbstractStreamWriter) -> None:
        """
        Write the whole payload to ``writer``.

        Args:
            writer: An AbstractStreamWriter instance that receives the data

        Legacy entry point without any length limit. Every subclass has to
        implement it for backwards compatibility; new code should prefer
        write_with_length(), to which this method is expected to delegate
        eventually as ``write_with_length(writer, None)``.
        """

    # write_with_length is new in aiohttp 3.12
    # it should be overridden by subclasses
    async def write_with_length(
        self, writer: AbstractStreamWriter, content_length: int | None
    ) -> None:
        """
        Write the payload, limited to ``content_length`` bytes.

        Args:
            writer: An AbstractStreamWriter instance that receives the data
            content_length: Maximum number of bytes to write (None for unlimited)

        Note:
            This base implementation ignores ``content_length`` and simply
            calls write(); it exists for subclasses that do not override
            this method. Payload types that can honour the limit should
            override it.
        """
        # Fallback for subclasses that only implement write().
        await self.write(writer)

    async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
        """
        Return the value as bytes.

        Convenience wrapper: decode() the value and encode the text again.
        The payload's own encoding, when set, wins over ``encoding``.
        """
        codec = self._encoding if self._encoding else encoding
        text = self.decode(codec, errors)
        return text.encode(codec)

    def _close(self) -> None:
        """
        Synchronous, non-blocking close kept for backwards compatibility.

        Only code that must clean up payloads synchronously should rely on
        it; it is planned to be removed in favour of the async close().

        WARNING: Implementations must be safe to call from inside the event
        loop and must not perform blocking I/O.

        WARNING: For some payload types (e.g., IOBasePayload) this has to be
        called from within an event loop; calling it elsewhere may raise
        RuntimeError.
        """
        # Nothing to release by default; subclasses override as needed.

    async def close(self) -> None:
        """
        Release resources held by the payload, if any.

        IMPORTANT: This method must not await anything that might not finish
        immediately, as it may be called during cleanup/cancellation. Schedule
        any long-running operations without awaiting them.

        This is intended to become the only supported close method.
        """
        self._close()
def __init__(
    self, value: bytes | bytearray | memoryview, *args: Any, **kwargs: Any
) -> None:
    """Wrap an in-memory bytes-like object as a payload.

    ``content_type`` defaults to ``application/octet-stream`` when the
    caller does not supply one. After the base initializer has run, the
    payload size is recorded (``nbytes`` for a memoryview, ``len()``
    otherwise) and a ``ResourceWarning`` is issued when that size is above
    ``TOO_LARGE_BYTES_BODY``.

    Raises:
        TypeError: if ``value`` is not bytes, bytearray or memoryview.
    """
    kwargs.setdefault("content_type", "application/octet-stream")
    super().__init__(value, *args, **kwargs)

    if not isinstance(value, (memoryview, bytes, bytearray)):
        raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")

    # len() of a memoryview counts items, not bytes, hence nbytes.
    self._size = value.nbytes if isinstance(value, memoryview) else len(value)

    if self._size <= TOO_LARGE_BYTES_BODY:
        return
    warnings.warn(
        "Sending a large body directly with raw bytes might"
        " lock the event loop. You should probably pass an "
        "io.BytesIO object instead",
        ResourceWarning,
        source=self,
    )
class StringPayload(BytesPayload):
    """Payload built from a ``str`` that is encoded to bytes up front."""

    def __init__(
        self,
        value: str,
        *args: Any,
        encoding: str | None = None,
        content_type: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Encode ``value`` and hand the bytes to the parent initializer.

        Resolution of the effective encoding and content type:

        - explicit ``encoding``: it is used as is, and a missing
          ``content_type`` becomes ``text/plain`` with that charset;
        - only ``content_type``: the encoding is its ``charset``
          parameter, falling back to UTF-8;
        - neither: UTF-8 and ``text/plain; charset=utf-8``.
        """
        if encoding is not None:
            real_encoding = encoding
            if content_type is None:
                content_type = "text/plain; charset=%s" % encoding
        elif content_type is not None:
            parsed = parse_mimetype(content_type)
            real_encoding = parsed.parameters.get("charset", "utf-8")
        else:
            real_encoding = "utf-8"
            content_type = "text/plain; charset=utf-8"

        encoded = value.encode(real_encoding)
        super().__init__(
            encoded,
            *args,
            encoding=real_encoding,
            content_type=content_type,
            **kwargs,
        )
def _read_and_available_len(
    self, remaining_content_len: int | None
) -> tuple[int | None, bytes]:
    """
    Position the file object, then return its size and the first chunk.

    Args:
        remaining_content_len: Optional cap on the number of bytes to read.
            A falsy value (None or 0) means READ_SIZE is used instead.

    Returns:
        A ``(size, chunk)`` tuple where ``size`` is the value of the
        ``size`` property (None when it cannot be determined) and
        ``chunk`` is the result of the first ``read()`` call.

    Size lookup and first read are combined so that a caller can run
    both in one executor job.
    """
    self._set_or_restore_start_position()
    total = self.size  # the property does I/O, so evaluate it exactly once
    size_cap = total if total else READ_SIZE
    length_cap = remaining_content_len if remaining_content_len else READ_SIZE
    first_chunk = self._value.read(min(READ_SIZE, size_cap, length_cap))
    return total, first_chunk
async def write_with_length(
    self, writer: AbstractStreamWriter, content_length: int | None
) -> None:
    """
    Write file-like payload with a specific content length constraint.

    Args:
        writer: An AbstractStreamWriter instance that handles the actual writing
        content_length: Maximum number of bytes to write (None for unlimited)

    This method streams the file content in chunks:

    1. Every read is dispatched through ``loop.run_in_executor`` so that
       file I/O does not run on the event loop thread
    2. Reads after the first one request at most READ_SIZE bytes
    3. Writing stops when either:
       - A read returns no data
       - All available file content has been written (when size is known)
       - The specified content_length has been reached

    This method does not close the underlying file object.
    """
    loop = asyncio.get_running_loop()
    total_written_len = 0
    remaining_content_len = content_length

    # Get initial data and available length
    available_len, chunk = await loop.run_in_executor(
        None, self._read_and_available_len, remaining_content_len
    )
    # Process data chunks until done
    while chunk:
        chunk_len = len(chunk)

        # Write data with or without length constraint
        if remaining_content_len is None:
            await writer.write(chunk)
        else:
            await writer.write(chunk[:remaining_content_len])
            # The full chunk length is subtracted even when only a slice was
            # written, so the counter may go negative; the stop check below
            # receives it and ends the loop in that case.
            remaining_content_len -= chunk_len
        total_written_len += chunk_len

        # Check if we're done writing
        if self._should_stop_writing(
            available_len, total_written_len, remaining_content_len
        ):
            return

        # Read next chunk
        chunk = await loop.run_in_executor(
            None,
            self._read,
            (
                min(READ_SIZE, remaining_content_len)
                if remaining_content_len is not None
                else READ_SIZE
            ),
        )
def _close(self) -> None:
    """
    Synchronously schedule closing of the wrapped file object.

    Kept for backwards compatibility; prefer the async close() method.

    WARNING: This method MUST be called from within an event loop.
    Calling it outside an event loop will raise RuntimeError.
    """
    if not self._consumed:
        # Flag first so that no further writes are attempted.
        self._consumed = True

        # The close runs in the executor and is deliberately not awaited.
        pending = asyncio.get_running_loop().run_in_executor(
            None, self._value.close
        )

        # Keep a strong reference until the future finishes so it cannot be
        # garbage collected early; the done-callback drops it again.
        _CLOSE_FUTURES.add(pending)
        pending.add_done_callback(_CLOSE_FUTURES.remove)
def __init__(
    self,
    value: TextIO,
    *args: Any,
    encoding: str | None = None,
    content_type: str | None = None,
    **kwargs: Any,
) -> None:
    """Resolve encoding and content type, then initialize the parent.

    - neither given: UTF-8 and ``text/plain; charset=utf-8``;
    - only ``content_type``: the encoding is its ``charset`` parameter,
      falling back to UTF-8;
    - ``encoding`` without ``content_type``: the content type becomes
      ``text/plain`` with that charset.
    """
    if encoding is None and content_type is None:
        encoding = "utf-8"
        content_type = "text/plain; charset=utf-8"
    elif encoding is None:
        parsed = parse_mimetype(content_type)
        encoding = parsed.parameters.get("charset", "utf-8")
    elif content_type is None:
        content_type = "text/plain; charset=%s" % encoding

    super().__init__(
        value,
        *args,
        content_type=content_type,
        encoding=encoding,
        **kwargs,
    )
def _read(self, remaining_content_len: int | None) -> bytes:
    """
    Read the next piece of text and return it encoded as bytes.

    Args:
        remaining_content_len: Optional size limit handed to the text
            stream's ``read()``. A falsy value (None or 0) means READ_SIZE
            is used instead.

    Returns:
        The text that was read, encoded with the payload's encoding, or
        with the ``str.encode()`` default when no encoding is set.

    Used for the reads that follow the initial _read_and_available_len
    call while streaming.
    """
    limit = remaining_content_len if remaining_content_len else READ_SIZE
    text = self._value.read(limit)
    if self._encoding:
        return text.encode(self._encoding)
    return text.encode()
async def write_with_length(
    self, writer: AbstractStreamWriter, content_length: int | None
) -> None:
    """
    Write BytesIO payload with a specific content length constraint.

    Args:
        writer: An AbstractStreamWriter instance that handles the actual writing
        content_length: Maximum number of bytes to write (None for unlimited)

    This implementation reads the buffer directly instead of using an
    executor:

    1. Rewinds to the stored start position before reading
    2. Reads content in READ_SIZE chunks
    3. Yields control back to the event loop before every chunk after the
       first one, to prevent blocking when dealing with large BytesIO
       objects
    4. Respects content_length constraints when specified

    The BytesIO object is not closed by this method.
    """
    self._set_or_restore_start_position()
    loop_count = 0
    remaining_bytes = content_length
    while chunk := self._value.read(READ_SIZE):
        if loop_count > 0:
            # Avoid blocking the event loop
            # if they pass a large BytesIO object
            # and we are not in the first iteration
            # of the loop
            await asyncio.sleep(0)
        if remaining_bytes is None:
            await writer.write(chunk)
        else:
            await writer.write(chunk[:remaining_bytes])
            # The whole chunk length is subtracted even if only a slice was
            # written; a result of zero or less means the limit is reached.
            remaining_bytes -= len(chunk)
            if remaining_bytes <= 0:
                return
        loop_count += 1
class JsonPayload(BytesPayload):
    """Bytes payload holding the JSON serialization of a value."""

    def __init__(
        self,
        value: Any,
        encoding: str = "utf-8",
        content_type: str = "application/json",
        dumps: JSONEncoder = json.dumps,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Serialize ``value`` with ``dumps`` and store the encoded result.

        The text returned by ``dumps`` is encoded with ``encoding``;
        ``content_type`` and ``encoding`` are forwarded to the parent
        initializer along with any remaining arguments.
        """
        serialized = dumps(value)
        body = serialized.encode(encoding)
        super().__init__(
            body,
            *args,
            content_type=content_type,
            encoding=encoding,
            **kwargs,
        )
Args: writer: An AbstractStreamWriter instance that handles the actual writing This method iterates through the async iterable and writes each chunk to the writer without any length constraint. Note: For new implementations that need length control, use write_with_length() directly. This method is maintained for backwards compatibility with existing code. """ await self.write_with_length(writer, None) async def write_with_length( self, writer: AbstractStreamWriter, content_length: int | None ) -> None: """ Write async iterable payload with a specific content length constraint. Args: writer: An AbstractStreamWriter instance that handles the actual writing content_length: Maximum number of bytes to write (None for unlimited) This implementation handles streaming of async iterable content with length constraints: 1. If cached chunks are available, writes from them 2. Otherwise iterates through the async iterable one chunk at a time 3. Respects content_length constraints when specified 4. Does NOT generate cache - that's done by as_bytes() """ # If we have cached chunks, use them if self._cached_chunks is not None: remaining_bytes = content_length for chunk in self._cached_chunks: if remaining_bytes is None: await writer.write(chunk) elif remaining_bytes > 0: await writer.write(chunk[:remaining_bytes]) remaining_bytes -= len(chunk) else: break return # If iterator is exhausted and we don't have cached chunks, nothing to write if self._iter is None: return # Stream from the iterator remaining_bytes = content_length try: while True: chunk = await anext(self._iter) if remaining_bytes is None: await writer.write(chunk) # If we have a content length limit elif remaining_bytes > 0: await writer.write(chunk[:remaining_bytes]) remaining_bytes -= len(chunk) # We still want to exhaust the iterator even # if we have reached the content length limit # since the file handle may not get closed by # the iterator if we don't do this except StopAsyncIteration: # Iterator is exhausted 
self._iter = None self._consumed = True # Mark as consumed when streamed without caching def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: """Decode the payload content as a string if cached chunks are available.""" if self._cached_chunks is not None: return b"".join(self._cached_chunks).decode(encoding, errors) raise TypeError("Unable to decode - content not cached. Call as_bytes() first.") async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: """ Return bytes representation of the value. This method reads the entire async iterable content and returns it as bytes. It generates and caches the chunks for future reuse. """ # If we have cached chunks, return them joined if self._cached_chunks is not None: return b"".join(self._cached_chunks) # If iterator is exhausted and no cache, return empty if self._iter is None: return b"" # Read all chunks and cache them chunks: list[bytes] = [] async for chunk in self._iter: chunks.append(chunk) # Iterator is exhausted, cache the chunks self._iter = None self._cached_chunks = chunks # Keep _consumed as False to allow reuse with cached chunks return b"".join(chunks) class StreamReaderPayload(AsyncIterablePayload): def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None: super().__init__(value.iter_any(), *args, **kwargs) PAYLOAD_REGISTRY = PayloadRegistry() PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview)) PAYLOAD_REGISTRY.register(StringPayload, str) PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO) PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase) PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO) PAYLOAD_REGISTRY.register(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom)) PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase) PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader) # try_last for giving a chance to more specialized async interables like # multipart.BodyPartReaderPayload override the 
class AiohttpServer(Protocol):
    """Call signature of the factory provided by the ``aiohttp_server`` fixture.

    The callable takes an ``Application``, an optional keyword-only
    ``port`` and arbitrary extra keyword arguments, and returns an
    awaitable that resolves to a ``TestServer``.
    """

    def __call__(
        self, app: Application, *, port: int | None = None, **kwargs: Any
    ) -> Awaitable[TestServer]: ...
def pytest_addoption(parser):  # type: ignore[no-untyped-def]
    """Register the aiohttp command line options on ``parser``.

    Adds, in this order, ``--aiohttp-fast``, ``--aiohttp-loop`` and
    ``--aiohttp-enable-loop-debug``.
    """
    option_table = (
        (
            "--aiohttp-fast",
            {
                "action": "store_true",
                "default": False,
                "help": "run tests faster by disabling extra checks",
            },
        ),
        (
            "--aiohttp-loop",
            {
                "action": "store",
                "default": "pyloop",
                "help": "run tests with specific loop: pyloop, uvloop or all",
            },
        ),
        (
            "--aiohttp-enable-loop-debug",
            {
                "action": "store_true",
                "default": False,
                "help": "enable event loop debug mode",
            },
        ),
    )
    for flag, settings in option_table:
        parser.addoption(flag, **settings)
@pytest.fixture
def loop_debug(request: pytest.FixtureRequest) -> bool:
    """Return the value of the ``--aiohttp-enable-loop-debug`` config option."""
    return request.config.getoption("--aiohttp-enable-loop-debug")  # type: ignore[no-any-return]
def pytest_generate_tests(metafunc):  # type: ignore[no-untyped-def]
    """Parametrize ``loop_factory`` from the ``--aiohttp-loop`` option.

    Tests that do not request the ``loop_factory`` fixture are left alone.
    The option value is a comma separated list of loop names; ``all`` is
    shorthand for ``pyloop,uvloop?``. A trailing ``?`` marks a loop as
    optional, so it is skipped silently when it is not available.

    Raises:
        ValueError: if a required loop name is not available. The message
            lists every loop that is available.
    """
    if "loop_factory" not in metafunc.fixturenames:
        return

    loops = metafunc.config.option.aiohttp_loop
    avail_factories: dict[str, Callable[[], asyncio.AbstractEventLoop]]
    avail_factories = {"pyloop": asyncio.new_event_loop}

    if uvloop is not None:
        avail_factories["uvloop"] = uvloop.new_event_loop

    if loops == "all":
        loops = "pyloop,uvloop?"

    factories = {}  # type: ignore[var-annotated]
    for name in loops.split(","):
        required = not name.endswith("?")
        name = name.strip(" ?")
        if name not in avail_factories:
            if required:
                # Report what can actually be selected, not the partially
                # built selection (empty when the first name is unknown).
                raise ValueError(
                    "Unknown loop '%s', available loops: %s"
                    % (name, list(avail_factories))
                )
            continue
        factories[name] = avail_factories[name]
    metafunc.parametrize(
        "loop_factory", list(factories.values()), ids=list(factories.keys())
    )
aiohttp_raw_server(handler, **kwargs)
    """
    servers = []

    async def go(
        handler: _RequestHandler[BaseRequest],
        *,
        port: int | None = None,
        **kwargs: Any,
    ) -> RawTestServer:
        server = RawTestServer(handler, port=port)
        await server.start_server(**kwargs)
        servers.append(server)
        return server

    yield go

    # Teardown: close every server the factory created during the test.
    async def finalize() -> None:
        while servers:
            await servers.pop().close()

    loop.run_until_complete(finalize())


@pytest.fixture
def aiohttp_client_cls() -> type[TestClient[Any, Any]]:
    """
    Client class to use in ``aiohttp_client`` factory.

    Use it for passing custom ``TestClient`` implementations.

    Example::

       class MyClient(TestClient):
           async def login(self, *, user, pw):
               payload = {"username": user, "password": pw}
               return await self.post("/login", json=payload)

       @pytest.fixture
       def aiohttp_client_cls():
           return MyClient

       async def test_login(aiohttp_client):
           app = web.Application()
           client = await aiohttp_client(app)
           await client.login(user="admin", pw="s3cr3t")

    """
    return TestClient


@pytest.fixture
def aiohttp_client(
    loop: asyncio.AbstractEventLoop, aiohttp_client_cls: type[TestClient[Any, Any]]
) -> Iterator[AiohttpClient]:
    """Factory to create a TestClient instance.

    aiohttp_client(app, **kwargs)
    aiohttp_client(server, **kwargs)
    aiohttp_client(raw_server, **kwargs)
    """
    clients = []

    @overload
    async def go(
        __param: Application,
        *,
        server_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> TestClient[Request, Application]: ...
    @overload
    async def go(
        __param: BaseTestServer[_Request],
        *,
        server_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> TestClient[_Request, None]: ...
    async def go(
        __param: Application | BaseTestServer[Any],
        *,
        server_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> TestClient[Any, Any]:
        # TODO(PY311): Use Unpack to specify ClientSession kwargs and server_kwargs.
        if isinstance(__param, Application):
            # An application is wrapped in a new TestServer first;
            # server_kwargs only applies on this path.
            server_kwargs = server_kwargs or {}
            server = TestServer(__param, **server_kwargs)
            client = aiohttp_client_cls(server, **kwargs)
        elif isinstance(__param, BaseTestServer):
            # An existing test server is used as-is.
            client = aiohttp_client_cls(__param, **kwargs)
        else:
            raise ValueError("Unknown argument type: %r" % type(__param))

        await client.start_server()
        clients.append(client)
        return client

    yield go

    # Teardown: close every client the factory created during the test.
    async def finalize() -> None:
        while clients:
            await clients.pop().close()

    loop.run_until_complete(finalize())

================================================
FILE: aiohttp/resolver.py
================================================
import asyncio
import socket
import weakref
from typing import Any, Optional

from .abc import AbstractResolver, ResolveResult

__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")


try:
    import aiodns

    # aiodns is only the default resolver when it provides getaddrinfo().
    aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")
except ImportError:
    aiodns = None  # type: ignore[assignment]
    aiodns_default = False


_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
_AI_ADDRCONFIG = socket.AI_ADDRCONFIG
if hasattr(socket, "AI_MASK"):
    _AI_ADDRCONFIG &= socket.AI_MASK


class ThreadedResolver(AbstractResolver):
    """Threaded resolver.

    Uses an Executor for synchronous getaddrinfo() calls.
    concurrent.futures.ThreadPoolExecutor is used by default.
    """

    def __init__(self) -> None:
        self._loop = asyncio.get_running_loop()

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> list[ResolveResult]:
        """Resolve *host* with the loop's getaddrinfo() into ResolveResult entries."""
        infos = await self._loop.getaddrinfo(
            host,
            port,
            type=socket.SOCK_STREAM,
            family=family,
            flags=_AI_ADDRCONFIG,
        )

        hosts: list[ResolveResult] = []
        for family, _, proto, _, address in infos:
            if family == socket.AF_INET6:
                if len(address) < 3:
                    # IPv6 is not supported by Python build,
                    # or IPv6 is not enabled in the host
                    continue
                if address[3]:
                    # This is essential for link-local IPv6 addresses.
# LL IPv6 is a VERY rare case. Strictly speaking, we should use # getnameinfo() unconditionally, but performance makes sense. resolved_host, _port = await self._loop.getnameinfo( address, _NAME_SOCKET_FLAGS ) port = int(_port) else: resolved_host, port = address[:2] else: # IPv4 assert family == socket.AF_INET resolved_host, port = address # type: ignore[misc] hosts.append( ResolveResult( hostname=host, host=resolved_host, port=port, family=family, proto=proto, flags=_NUMERIC_SOCKET_FLAGS, ) ) return hosts async def close(self) -> None: pass class AsyncResolver(AbstractResolver): """Use the `aiodns` package to make asynchronous DNS lookups""" def __init__(self, *args: Any, **kwargs: Any) -> None: if aiodns is None: raise RuntimeError("Resolver requires aiodns library") self._loop = asyncio.get_running_loop() self._manager: _DNSResolverManager | None = None # If custom args are provided, create a dedicated resolver instance # This means each AsyncResolver with custom args gets its own # aiodns.DNSResolver instance if args or kwargs: self._resolver = aiodns.DNSResolver(*args, **kwargs) return # Use the shared resolver from the manager for default arguments self._manager = _DNSResolverManager() self._resolver = self._manager.get_resolver(self, self._loop) async def resolve( self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET ) -> list[ResolveResult]: try: resp = await self._resolver.getaddrinfo( host, port=port, type=socket.SOCK_STREAM, family=family, flags=_AI_ADDRCONFIG, ) except aiodns.error.DNSError as exc: msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed" raise OSError(None, msg) from exc hosts: list[ResolveResult] = [] for node in resp.nodes: address: tuple[bytes, int] | tuple[bytes, int, int, int] = node.addr if node.family == socket.AF_INET6: if len(address) > 3 and address[3]: # This is essential for link-local IPv6 addresses. # LL IPv6 is a VERY rare case. 
Strictly speaking, we should use # getnameinfo() unconditionally, but performance makes sense. result = await self._resolver.getnameinfo( (address[0].decode("ascii"), *address[1:]), _NAME_SOCKET_FLAGS, ) resolved_host = result.node else: resolved_host = address[0].decode("ascii") port = address[1] else: # IPv4 assert node.family == socket.AF_INET resolved_host = address[0].decode("ascii") port = address[1] hosts.append( ResolveResult( hostname=host, host=resolved_host, port=port, family=node.family, proto=0, flags=_NUMERIC_SOCKET_FLAGS, ) ) if not hosts: raise OSError(None, "DNS lookup failed") return hosts async def close(self) -> None: if self._manager: # Release the resolver from the manager if using the shared resolver self._manager.release_resolver(self, self._loop) self._manager = None # Clear reference to manager self._resolver = None # type: ignore[assignment] # Clear reference to resolver return # Otherwise cancel our dedicated resolver if self._resolver is not None: self._resolver.cancel() self._resolver = None # type: ignore[assignment] # Clear reference class _DNSResolverManager: """Manager for aiodns.DNSResolver objects. This class manages shared aiodns.DNSResolver instances with no custom arguments across different event loops. """ _instance: Optional["_DNSResolverManager"] = None def __new__(cls) -> "_DNSResolverManager": if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._init() return cls._instance def _init(self) -> None: # Use WeakKeyDictionary to allow event loops to be garbage collected self._loop_data: weakref.WeakKeyDictionary[ asyncio.AbstractEventLoop, tuple[aiodns.DNSResolver, weakref.WeakSet[AsyncResolver]], ] = weakref.WeakKeyDictionary() def get_resolver( self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop ) -> "aiodns.DNSResolver": """Get or create the shared aiodns.DNSResolver instance for a specific event loop. Args: client: The AsyncResolver instance requesting the resolver. 
This is required to track resolver usage. loop: The event loop to use for the resolver. """ # Create a new resolver and client set for this loop if it doesn't exist if loop not in self._loop_data: resolver = aiodns.DNSResolver(loop=loop) client_set: weakref.WeakSet[AsyncResolver] = weakref.WeakSet() self._loop_data[loop] = (resolver, client_set) else: # Get the existing resolver and client set resolver, client_set = self._loop_data[loop] # Register this client with the loop client_set.add(client) return resolver def release_resolver( self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop ) -> None: """Release the resolver for an AsyncResolver client when it's closed. Args: client: The AsyncResolver instance to release. loop: The event loop the resolver was using. """ # Remove client from its loop's tracking current_loop_data = self._loop_data.get(loop) if current_loop_data is None: return resolver, client_set = current_loop_data client_set.discard(client) # If no more clients for this loop, cancel and remove its resolver if not client_set: if resolver is not None: resolver.cancel() del self._loop_data[loop] _DefaultType = type[AsyncResolver | ThreadedResolver] DefaultResolver: _DefaultType = AsyncResolver if aiodns_default else ThreadedResolver ================================================ FILE: aiohttp/streams.py ================================================ import asyncio import collections import warnings from collections.abc import Awaitable, Callable from typing import Final, Generic, TypeVar from .base_protocol import BaseProtocol from .helpers import ( _EXC_SENTINEL, BaseTimerContext, TimerNoop, set_exception, set_result, ) from .http_exceptions import LineTooLong from .log import internal_logger __all__ = ( "EMPTY_PAYLOAD", "EofStream", "StreamReader", "DataQueue", ) _T = TypeVar("_T") class EofStream(Exception): """eof stream indication.""" class AsyncStreamIterator(Generic[_T]): __slots__ = ("read_func",) def __init__(self, read_func: 
Callable[[], Awaitable[_T]]) -> None: self.read_func = read_func def __aiter__(self) -> "AsyncStreamIterator[_T]": return self async def __anext__(self) -> _T: try: rv = await self.read_func() except EofStream: raise StopAsyncIteration if rv == b"": raise StopAsyncIteration return rv class ChunkTupleAsyncStreamIterator: __slots__ = ("_stream",) def __init__(self, stream: "StreamReader") -> None: self._stream = stream def __aiter__(self) -> "ChunkTupleAsyncStreamIterator": return self async def __anext__(self) -> tuple[bytes, bool]: rv = await self._stream.readchunk() if rv == (b"", False): raise StopAsyncIteration return rv class AsyncStreamReaderMixin: __slots__ = () def __aiter__(self) -> AsyncStreamIterator[bytes]: return AsyncStreamIterator(self.readline) # type: ignore[attr-defined] def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]: """Returns an asynchronous iterator that yields chunks of size n.""" return AsyncStreamIterator(lambda: self.read(n)) # type: ignore[attr-defined] def iter_any(self) -> AsyncStreamIterator[bytes]: """Yield all available data as soon as it is received.""" return AsyncStreamIterator(self.readany) # type: ignore[attr-defined] def iter_chunks(self) -> ChunkTupleAsyncStreamIterator: """Yield chunks of data as they are received by the server. The yielded objects are tuples of (bytes, bool) as returned by the StreamReader.readchunk method. """ return ChunkTupleAsyncStreamIterator(self) # type: ignore[arg-type] class StreamReader(AsyncStreamReaderMixin): """An enhancement of asyncio.StreamReader. Supports asynchronous iteration by line, chunk or as available:: async for line in reader: ... async for chunk in reader.iter_chunked(1024): ... async for slice in reader.iter_any(): ... 
""" __slots__ = ( "_protocol", "_low_water", "_high_water", "_low_water_chunks", "_high_water_chunks", "_loop", "_size", "_cursor", "_http_chunk_splits", "_buffer", "_buffer_offset", "_eof", "_waiter", "_eof_waiter", "_exception", "_timer", "_eof_callbacks", "_eof_counter", "total_bytes", "total_compressed_bytes", ) def __init__( self, protocol: BaseProtocol, limit: int, *, timer: BaseTimerContext | None = None, loop: asyncio.AbstractEventLoop, ) -> None: self._protocol = protocol self._low_water = limit self._high_water = limit * 2 # Ensure high_water_chunks >= 3 so it's always > low_water_chunks. self._high_water_chunks = max(3, limit // 4) # Use max(2, ...) because there's always at least 1 chunk split remaining # (the current position), so we need low_water >= 2 to allow resume. self._low_water_chunks = max(2, self._high_water_chunks // 2) self._loop = loop self._size = 0 self._cursor = 0 self._http_chunk_splits: collections.deque[int] | None = None self._buffer: collections.deque[bytes] = collections.deque() self._buffer_offset = 0 self._eof = False self._waiter: asyncio.Future[None] | None = None self._eof_waiter: asyncio.Future[None] | None = None self._exception: type[BaseException] | BaseException | None = None self._timer = TimerNoop() if timer is None else timer self._eof_callbacks: list[Callable[[], None]] = [] self._eof_counter = 0 self.total_bytes = 0 self.total_compressed_bytes: int | None = None def __repr__(self) -> str: info = [self.__class__.__name__] if self._size: info.append("%d bytes" % self._size) if self._eof: info.append("eof") if self._low_water != 2**16: # default limit info.append("low=%d high=%d" % (self._low_water, self._high_water)) if self._waiter: info.append("w=%r" % self._waiter) if self._exception: info.append("e=%r" % self._exception) return "<%s>" % " ".join(info) def get_read_buffer_limits(self) -> tuple[int, int]: return (self._low_water, self._high_water) def exception(self) -> type[BaseException] | BaseException | None: 
return self._exception def set_exception( self, exc: type[BaseException] | BaseException, exc_cause: BaseException = _EXC_SENTINEL, ) -> None: self._exception = exc self._eof_callbacks.clear() waiter = self._waiter if waiter is not None: self._waiter = None set_exception(waiter, exc, exc_cause) waiter = self._eof_waiter if waiter is not None: self._eof_waiter = None set_exception(waiter, exc, exc_cause) def on_eof(self, callback: Callable[[], None]) -> None: if self._eof: try: callback() except Exception: internal_logger.exception("Exception in eof callback") else: self._eof_callbacks.append(callback) def feed_eof(self) -> None: self._eof = True waiter = self._waiter if waiter is not None: self._waiter = None set_result(waiter, None) waiter = self._eof_waiter if waiter is not None: self._eof_waiter = None set_result(waiter, None) if self._protocol._reading_paused: self._protocol.resume_reading() for cb in self._eof_callbacks: try: cb() except Exception: internal_logger.exception("Exception in eof callback") self._eof_callbacks.clear() def is_eof(self) -> bool: """Return True if 'feed_eof' was called.""" return self._eof def at_eof(self) -> bool: """Return True if the buffer is empty and 'feed_eof' was called.""" return self._eof and not self._buffer async def wait_eof(self) -> None: if self._eof: return assert self._eof_waiter is None self._eof_waiter = self._loop.create_future() try: await self._eof_waiter finally: self._eof_waiter = None @property def total_raw_bytes(self) -> int: if self.total_compressed_bytes is None: return self.total_bytes return self.total_compressed_bytes def unread_data(self, data: bytes) -> None: """rollback reading some data from stream, inserting it to buffer head.""" warnings.warn( "unread_data() is deprecated " "and will be removed in future releases (#3260)", DeprecationWarning, stacklevel=2, ) if not data: return if self._buffer_offset: self._buffer[0] = self._buffer[0][self._buffer_offset :] self._buffer_offset = 0 self._size += 
len(data) self._cursor -= len(data) self._buffer.appendleft(data) self._eof_counter = 0 def feed_data(self, data: bytes) -> None: assert not self._eof, "feed_data after feed_eof" if not data: return data_len = len(data) self._size += data_len self._buffer.append(data) self.total_bytes += data_len waiter = self._waiter if waiter is not None: self._waiter = None set_result(waiter, None) if self._size > self._high_water and not self._protocol._reading_paused: self._protocol.pause_reading() def begin_http_chunk_receiving(self) -> None: if self._http_chunk_splits is None: if self.total_bytes: raise RuntimeError( "Called begin_http_chunk_receiving when some data was already fed" ) self._http_chunk_splits = collections.deque() def end_http_chunk_receiving(self) -> None: if self._http_chunk_splits is None: raise RuntimeError( "Called end_chunk_receiving without calling " "begin_chunk_receiving first" ) # self._http_chunk_splits contains logical byte offsets from start of # the body transfer. Each offset is the offset of the end of a chunk. # "Logical" means bytes, accessible for a user. # If no chunks containing logical data were received, current position # is difinitely zero. pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0 if self.total_bytes == pos: # We should not add empty chunks here. So we check for that. # Note, when chunked + gzip is used, we can receive a chunk # of compressed data, but that data may not be enough for gzip FSM # to yield any uncompressed data. That's why current position may # not change after receiving a chunk. return self._http_chunk_splits.append(self.total_bytes) # If we get too many small chunks before self._high_water is reached, then any # .read() call becomes computationally expensive, and could block the event loop # for too long, hence an additional self._high_water_chunks here. 
if ( len(self._http_chunk_splits) > self._high_water_chunks and not self._protocol._reading_paused ): self._protocol.pause_reading() # wake up readchunk when end of http chunk received waiter = self._waiter if waiter is not None: self._waiter = None set_result(waiter, None) async def _wait(self, func_name: str) -> None: if not self._protocol.connected: raise RuntimeError("Connection closed.") # StreamReader uses a future to link the protocol feed_data() method # to a read coroutine. Running two read coroutines at the same time # would have an unexpected behaviour. It would not possible to know # which coroutine would get the next data. if self._waiter is not None: raise RuntimeError( "%s() called while another coroutine is " "already waiting for incoming data" % func_name ) waiter = self._waiter = self._loop.create_future() try: with self._timer: await waiter finally: self._waiter = None async def readline(self, *, max_line_length: int | None = None) -> bytes: return await self.readuntil(max_size=max_line_length) async def readuntil( self, separator: bytes = b"\n", *, max_size: int | None = None ) -> bytes: seplen = len(separator) if seplen == 0: raise ValueError("Separator should be at least one-byte string") if self._exception is not None: raise self._exception chunk = b"" chunk_size = 0 not_enough = True max_size = max_size or self._high_water while not_enough: while self._buffer and not_enough: offset = self._buffer_offset ichar = self._buffer[0].find(separator, offset) + 1 # Read from current offset to found separator or to the end. 
data = self._read_nowait_chunk( ichar - offset + seplen - 1 if ichar else -1 ) chunk += data chunk_size += len(data) if ichar: not_enough = False if chunk_size > max_size: raise LineTooLong(chunk[:100] + b"...", max_size) if self._eof: break if not_enough: await self._wait("readuntil") return chunk async def read(self, n: int = -1) -> bytes: if self._exception is not None: raise self._exception if not n: return b"" if n < 0: # This used to just loop creating a new waiter hoping to # collect everything in self._buffer, but that would # deadlock if the subprocess sends more than self.limit # bytes. So just call self.readany() until EOF. blocks = [] while True: block = await self.readany() if not block: break blocks.append(block) return b"".join(blocks) # TODO: should be `if` instead of `while` # because waiter maybe triggered on chunk end, # without feeding any data while not self._buffer and not self._eof: await self._wait("read") return self._read_nowait(n) async def readany(self) -> bytes: if self._exception is not None: raise self._exception # TODO: should be `if` instead of `while` # because waiter maybe triggered on chunk end, # without feeding any data while not self._buffer and not self._eof: await self._wait("readany") return self._read_nowait(-1) async def readchunk(self) -> tuple[bytes, bool]: """Returns a tuple of (data, end_of_http_chunk). When chunked transfer encoding is used, end_of_http_chunk is a boolean indicating if the end of the data corresponds to the end of a HTTP chunk , otherwise it is always False. 
""" while True: if self._exception is not None: raise self._exception while self._http_chunk_splits: pos = self._http_chunk_splits.popleft() if pos == self._cursor: return (b"", True) if pos > self._cursor: return (self._read_nowait(pos - self._cursor), True) internal_logger.warning( "Skipping HTTP chunk end due to data " "consumption beyond chunk boundary" ) if self._buffer: return (self._read_nowait_chunk(-1), False) # return (self._read_nowait(-1), False) if self._eof: # Special case for signifying EOF. # (b'', True) is not a final return value actually. return (b"", False) await self._wait("readchunk") async def readexactly(self, n: int) -> bytes: if self._exception is not None: raise self._exception blocks: list[bytes] = [] while n > 0: block = await self.read(n) if not block: partial = b"".join(blocks) raise asyncio.IncompleteReadError(partial, len(partial) + n) blocks.append(block) n -= len(block) return b"".join(blocks) def read_nowait(self, n: int = -1) -> bytes: # default was changed to be consistent with .read(-1) # # I believe the most users don't know about the method and # they are not affected. if self._exception is not None: raise self._exception if self._waiter and not self._waiter.done(): raise RuntimeError( "Called while some coroutine is waiting for incoming data." 
) return self._read_nowait(n) def _read_nowait_chunk(self, n: int) -> bytes: first_buffer = self._buffer[0] offset = self._buffer_offset if n != -1 and len(first_buffer) - offset > n: data = first_buffer[offset : offset + n] self._buffer_offset += n elif offset: self._buffer.popleft() data = first_buffer[offset:] self._buffer_offset = 0 else: data = self._buffer.popleft() data_len = len(data) self._size -= data_len self._cursor += data_len chunk_splits = self._http_chunk_splits # Prevent memory leak: drop useless chunk splits while chunk_splits and chunk_splits[0] < self._cursor: chunk_splits.popleft() if ( self._protocol._reading_paused and self._size < self._low_water and ( self._http_chunk_splits is None or len(self._http_chunk_splits) < self._low_water_chunks ) ): self._protocol.resume_reading() return data def _read_nowait(self, n: int) -> bytes: """Read not more than n bytes, or whole buffer if n == -1""" self._timer.assert_timeout() chunks = [] while self._buffer: chunk = self._read_nowait_chunk(n) chunks.append(chunk) if n != -1: n -= len(chunk) if n == 0: break return b"".join(chunks) if chunks else b"" class EmptyStreamReader(StreamReader): # lgtm [py/missing-call-to-init] __slots__ = ("_read_eof_chunk",) def __init__(self) -> None: self._read_eof_chunk = False self.total_bytes = 0 def __repr__(self) -> str: return "<%s>" % self.__class__.__name__ def exception(self) -> BaseException | None: return None def set_exception( self, exc: type[BaseException] | BaseException, exc_cause: BaseException = _EXC_SENTINEL, ) -> None: pass def on_eof(self, callback: Callable[[], None]) -> None: try: callback() except Exception: internal_logger.exception("Exception in eof callback") def feed_eof(self) -> None: pass def is_eof(self) -> bool: return True def at_eof(self) -> bool: return True async def wait_eof(self) -> None: return def feed_data(self, data: bytes) -> None: pass async def readline(self, *, max_line_length: int | None = None) -> bytes: return b"" async def 
read(self, n: int = -1) -> bytes: return b"" # TODO add async def readuntil async def readany(self) -> bytes: return b"" async def readchunk(self) -> tuple[bytes, bool]: if not self._read_eof_chunk: self._read_eof_chunk = True return (b"", False) return (b"", True) async def readexactly(self, n: int) -> bytes: raise asyncio.IncompleteReadError(b"", n) def read_nowait(self, n: int = -1) -> bytes: return b"" EMPTY_PAYLOAD: Final[StreamReader] = EmptyStreamReader() class DataQueue(Generic[_T]): """DataQueue is a general-purpose blocking queue with one reader.""" def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop = loop self._eof = False self._waiter: asyncio.Future[None] | None = None self._exception: type[BaseException] | BaseException | None = None self._buffer: collections.deque[_T] = collections.deque() def __len__(self) -> int: return len(self._buffer) def is_eof(self) -> bool: return self._eof def at_eof(self) -> bool: return self._eof and not self._buffer def exception(self) -> type[BaseException] | BaseException | None: return self._exception def set_exception( self, exc: type[BaseException] | BaseException, exc_cause: BaseException = _EXC_SENTINEL, ) -> None: self._eof = True self._exception = exc if (waiter := self._waiter) is not None: self._waiter = None set_exception(waiter, exc, exc_cause) def feed_data(self, data: _T) -> None: self._buffer.append(data) if (waiter := self._waiter) is not None: self._waiter = None set_result(waiter, None) def feed_eof(self) -> None: self._eof = True if (waiter := self._waiter) is not None: self._waiter = None set_result(waiter, None) async def read(self) -> _T: if not self._buffer and not self._eof: assert not self._waiter self._waiter = self._loop.create_future() try: await self._waiter except (asyncio.CancelledError, asyncio.TimeoutError): self._waiter = None raise if self._buffer: return self._buffer.popleft() if self._exception is not None: raise self._exception raise EofStream def __aiter__(self) 
-> AsyncStreamIterator[_T]: return AsyncStreamIterator(self.read) ================================================ FILE: aiohttp/tcp_helpers.py ================================================ """Helper methods to tune a TCP connection""" import asyncio import socket from contextlib import suppress __all__ = ("tcp_keepalive", "tcp_nodelay") if hasattr(socket, "SO_KEEPALIVE"): def tcp_keepalive(transport: asyncio.Transport) -> None: sock = transport.get_extra_info("socket") if sock is not None: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) else: def tcp_keepalive(transport: asyncio.Transport) -> None: """Noop when keepalive not supported.""" def tcp_nodelay(transport: asyncio.Transport, value: bool) -> None: sock = transport.get_extra_info("socket") if sock is None: return if sock.family not in (socket.AF_INET, socket.AF_INET6): return value = bool(value) # socket may be closed already, on windows OSError get raised with suppress(OSError): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value) ================================================ FILE: aiohttp/test_utils.py ================================================ """Utilities shared by tests.""" import asyncio import contextlib import gc import ipaddress import os import socket import sys from abc import ABC, abstractmethod from collections.abc import Callable, Iterator from types import TracebackType from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast, overload from unittest import IsolatedAsyncioTestCase, mock from aiosignal import Signal from multidict import CIMultiDict, CIMultiDictProxy from yarl import URL import aiohttp from aiohttp.client import ( _BaseRequestContextManager, _RequestContextManager, _RequestOptions, _WSRequestContextManager, ) from . 
import ClientSession, hdrs from .abc import AbstractCookieJar, AbstractStreamWriter from .client_reqrep import ClientResponse from .client_ws import ClientWebSocketResponse from .http import HttpVersion, RawRequestMessage from .streams import EMPTY_PAYLOAD, StreamReader from .typedefs import LooseHeaders, StrOrURL from .web import ( Application, AppRunner, BaseRequest, BaseRunner, Request, RequestHandler, Server, ServerRunner, SockSite, UrlMappingMatchInfo, ) from .web_protocol import _RequestHandler if TYPE_CHECKING: from ssl import SSLContext else: SSLContext = Any if sys.version_info >= (3, 11) and TYPE_CHECKING: from typing import Unpack if sys.version_info >= (3, 11): from typing import Self else: Self = Any _ApplicationNone = TypeVar("_ApplicationNone", Application, None) _Request = TypeVar("_Request", bound=BaseRequest) REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin" def get_unused_port_socket( host: str, family: socket.AddressFamily = socket.AF_INET ) -> socket.socket: return get_port_socket(host, 0, family) def get_port_socket( host: str, port: int, family: socket.AddressFamily = socket.AF_INET ) -> socket.socket: s = socket.socket(family, socket.SOCK_STREAM) if REUSE_ADDRESS: # Windows has different semantics for SO_REUSEADDR, # so don't set it. 
Ref: # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind((host, port)) return s def unused_port() -> int: """Return a port that is unused on the current host.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) return cast(int, s.getsockname()[1]) class BaseTestServer(ABC, Generic[_Request]): __test__ = False def __init__( self, *, scheme: str = "", host: str = "127.0.0.1", port: int | None = None, skip_url_asserts: bool = False, socket_factory: Callable[ [str, int, socket.AddressFamily], socket.socket ] = get_port_socket, **kwargs: Any, ) -> None: self.runner: BaseRunner[_Request] | None = None self._root: URL | None = None self.host = host self.port = port or 0 self._closed = False self.scheme = scheme self.skip_url_asserts = skip_url_asserts self.socket_factory = socket_factory async def start_server(self, **kwargs: Any) -> None: if self.runner: return self._ssl = kwargs.pop("ssl", None) self.runner = await self._make_runner(handler_cancellation=True, **kwargs) await self.runner.setup() absolute_host = self.host try: version = ipaddress.ip_address(self.host).version except ValueError: version = 4 if version == 6: absolute_host = f"[{self.host}]" family = socket.AF_INET6 if version == 6 else socket.AF_INET _sock = self.socket_factory(self.host, self.port, family) self.host, self.port = _sock.getsockname()[:2] site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl) await site.start() server = site._server assert server is not None sockets = server.sockets assert sockets is not None self.port = sockets[0].getsockname()[1] if not self.scheme: self.scheme = "https" if self._ssl else "http" self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}") @abstractmethod async def _make_runner(self, **kwargs: Any) -> BaseRunner[_Request]: """Return a new runner for the server.""" # TODO(PY311): Use Unpack to specify 
Server kwargs. def make_url(self, path: StrOrURL) -> URL: assert self._root is not None url = URL(path) if not self.skip_url_asserts: assert not url.absolute return self._root.join(url) else: return URL(str(self._root) + str(path)) @property def started(self) -> bool: return self.runner is not None @property def closed(self) -> bool: return self._closed @property def handler(self) -> Server[_Request]: # for backward compatibility # web.Server instance runner = self.runner assert runner is not None assert runner.server is not None return runner.server async def close(self) -> None: """Close all fixtures created by the test client. After that point, the TestClient is no longer usable. This is an idempotent function: running close multiple times will not have any additional effects. close is also run when the object is garbage collected, and on exit when used as a context manager. """ if self.started and not self.closed: assert self.runner is not None await self.runner.cleanup() self._root = None self.port = 0 self._closed = True async def __aenter__(self) -> Self: await self.start_server() return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: await self.close() class TestServer(BaseTestServer[Request]): def __init__( self, app: Application, *, scheme: str = "", host: str = "127.0.0.1", port: int | None = None, **kwargs: Any, ): self.app = app super().__init__(scheme=scheme, host=host, port=port, **kwargs) async def _make_runner(self, **kwargs: Any) -> AppRunner: # TODO(PY311): Use Unpack to specify Server kwargs. 
return AppRunner(self.app, **kwargs) class RawTestServer(BaseTestServer[BaseRequest]): def __init__( self, handler: _RequestHandler[BaseRequest], *, scheme: str = "", host: str = "127.0.0.1", port: int | None = None, **kwargs: Any, ) -> None: self._handler = handler super().__init__(scheme=scheme, host=host, port=port, **kwargs) async def _make_runner(self, **kwargs: Any) -> ServerRunner: # TODO(PY311): Use Unpack to specify Server kwargs. srv = Server(self._handler, **kwargs) return ServerRunner(srv, **kwargs) class TestClient(Generic[_Request, _ApplicationNone]): """ A test client implementation. To write functional tests for aiohttp based servers. """ __test__ = False @overload def __init__( self: "TestClient[Request, Application]", server: TestServer, *, cookie_jar: AbstractCookieJar | None = None, **kwargs: Any, ) -> None: ... @overload def __init__( self: "TestClient[_Request, None]", server: BaseTestServer[_Request], *, cookie_jar: AbstractCookieJar | None = None, **kwargs: Any, ) -> None: ... def __init__( # type: ignore[misc] self, server: BaseTestServer[_Request], *, cookie_jar: AbstractCookieJar | None = None, **kwargs: Any, ) -> None: # TODO(PY311): Use Unpack to specify ClientSession kwargs. 
if not isinstance(server, BaseTestServer): raise TypeError( "server must be TestServer instance, found type: %r" % type(server) ) self._server = server if cookie_jar is None: cookie_jar = aiohttp.CookieJar(unsafe=True) self._session = ClientSession(cookie_jar=cookie_jar, **kwargs) self._session._retry_connection = False self._closed = False self._responses: list[ClientResponse] = [] self._websockets: list[ClientWebSocketResponse[bool]] = [] async def start_server(self) -> None: await self._server.start_server() @property def scheme(self) -> str | object: return self._server.scheme @property def host(self) -> str: return self._server.host @property def port(self) -> int: return self._server.port @property def server(self) -> BaseTestServer[_Request]: return self._server @property def app(self) -> _ApplicationNone: return getattr(self._server, "app", None) # type: ignore[return-value] @property def session(self) -> ClientSession: """An internal aiohttp.ClientSession. Unlike the methods on the TestClient, client session requests do not automatically include the host in the url queried, and will require an absolute path to the resource. """ return self._session def make_url(self, path: StrOrURL) -> URL: return self._server.make_url(path) async def _request( self, method: str, path: StrOrURL, **kwargs: Any ) -> ClientResponse: resp = await self._session.request(method, self.make_url(path), **kwargs) # save it to close later self._responses.append(resp) return resp if sys.version_info >= (3, 11) and TYPE_CHECKING: def request( self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions] ) -> _RequestContextManager: ... def get( self, path: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> _RequestContextManager: ... def options( self, path: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> _RequestContextManager: ... def head( self, path: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> _RequestContextManager: ... 
def post( self, path: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> _RequestContextManager: ... def put( self, path: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> _RequestContextManager: ... def patch( self, path: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> _RequestContextManager: ... def delete( self, path: StrOrURL, **kwargs: Unpack[_RequestOptions], ) -> _RequestContextManager: ... else: def request( self, method: str, path: StrOrURL, **kwargs: Any ) -> _RequestContextManager: """Routes a request to tested http server. The interface is identical to aiohttp.ClientSession.request, except the loop kwarg is overridden by the instance used by the test server. """ return _RequestContextManager(self._request(method, path, **kwargs)) def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP GET request.""" return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs)) def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP POST request.""" return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs)) def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP OPTIONS request.""" return _RequestContextManager( self._request(hdrs.METH_OPTIONS, path, **kwargs) ) def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP HEAD request.""" return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs)) def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP PUT request.""" return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs)) def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP PATCH request.""" return _RequestContextManager( self._request(hdrs.METH_PATCH, path, **kwargs) ) def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP PATCH request.""" return 
_RequestContextManager( self._request(hdrs.METH_DELETE, path, **kwargs) ) @overload def ws_connect( self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[True]]]": ... @overload def ws_connect( self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any ) -> "_BaseRequestContextManager[ClientWebSocketResponse[Literal[False]]]": ... @overload def ws_connect( self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": ... def ws_connect( self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any ) -> "_BaseRequestContextManager[ClientWebSocketResponse[bool]]": """Initiate websocket connection. The api corresponds to aiohttp.ClientSession.ws_connect. """ return _WSRequestContextManager( self._ws_connect(path, decode_text=decode_text, **kwargs) ) @overload async def _ws_connect( self, path: StrOrURL, *, decode_text: Literal[True] = ..., **kwargs: Any ) -> "ClientWebSocketResponse[Literal[True]]": ... @overload async def _ws_connect( self, path: StrOrURL, *, decode_text: Literal[False], **kwargs: Any ) -> "ClientWebSocketResponse[Literal[False]]": ... @overload async def _ws_connect( self, path: StrOrURL, *, decode_text: bool = ..., **kwargs: Any ) -> "ClientWebSocketResponse[bool]": ... async def _ws_connect( self, path: StrOrURL, *, decode_text: bool = True, **kwargs: Any ) -> "ClientWebSocketResponse[bool]": ws = await self._session.ws_connect( self.make_url(path), decode_text=decode_text, **kwargs ) self._websockets.append(ws) return ws async def close(self) -> None: """Close all fixtures created by the test client. After that point, the TestClient is no longer usable. This is an idempotent function: running close multiple times will not have any additional effects. close is also run on exit when used as a(n) (asynchronous) context manager. 
""" if not self._closed: for resp in self._responses: resp.close() for ws in self._websockets: await ws.close() await self._session.close() await self._server.close() self._closed = True async def __aenter__(self) -> Self: await self.start_server() return self async def __aexit__( self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, ) -> None: await self.close() class AioHTTPTestCase(IsolatedAsyncioTestCase, ABC): """A base class to allow for unittest web applications using aiohttp. Provides the following: * self.client (aiohttp.test_utils.TestClient): an aiohttp test client. * self.app (aiohttp.web.Application): the application returned by self.get_application() Note that the TestClient's methods are asynchronous: you have to execute function on the test client using asynchronous methods. """ @abstractmethod async def get_application(self) -> Application: """Get application. This method should be overridden to return the aiohttp.web.Application object to test. """ async def asyncSetUp(self) -> None: self.app = await self.get_application() self.server = await self.get_server(self.app) self.client = await self.get_client(self.server) await self.client.start_server() async def asyncTearDown(self) -> None: await self.client.close() async def get_server(self, app: Application) -> TestServer: """Return a TestServer instance.""" return TestServer(app) async def get_client(self, server: TestServer) -> TestClient[Request, Application]: """Return a TestClient instance.""" return TestClient(server) _LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop] @contextlib.contextmanager def loop_context( loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False ) -> Iterator[asyncio.AbstractEventLoop]: """A contextmanager that creates an event_loop, for test purposes. Handles the creation and cleanup of a test loop. 
""" loop = setup_test_loop(loop_factory) yield loop teardown_test_loop(loop, fast=fast) def setup_test_loop( loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, ) -> asyncio.AbstractEventLoop: """Create and return an asyncio.BaseEventLoop instance. The caller should also call teardown_test_loop, once they are done with the loop. """ loop = loop_factory() asyncio.set_event_loop(loop) return loop def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None: """Teardown and cleanup an event_loop created by setup_test_loop.""" closed = loop.is_closed() if not closed: loop.call_soon(loop.stop) loop.run_forever() loop.close() if not fast: gc.collect() asyncio.set_event_loop(None) def _create_app_mock() -> mock.MagicMock: def get_dict(app: Any, key: str) -> Any: return app.__app_dict[key] def set_dict(app: Any, key: str, value: Any) -> None: app.__app_dict[key] = value app = mock.MagicMock(spec=Application) app.__app_dict = {} app.__getitem__ = get_dict app.__setitem__ = set_dict app.on_response_prepare = Signal(app) app.on_response_prepare.freeze() return app def _create_transport(sslcontext: SSLContext | None = None) -> mock.Mock: transport = mock.Mock() def get_extra_info(key: str) -> SSLContext | None: if key == "sslcontext": return sslcontext else: return None transport.get_extra_info.side_effect = get_extra_info return transport def make_mocked_request( method: str, path: str, headers: LooseHeaders | None = None, *, match_info: dict[str, str] | None = None, version: HttpVersion = HttpVersion(1, 1), closing: bool = False, app: Application | None = None, writer: AbstractStreamWriter | None = None, protocol: RequestHandler[Request] | None = None, transport: asyncio.Transport | None = None, payload: StreamReader = EMPTY_PAYLOAD, sslcontext: SSLContext | None = None, client_max_size: int = 1024**2, loop: Any = ..., ) -> Request: """Creates mocked web.Request testing purposes. 
Useful in unit tests, when spinning full web server is overkill or specific conditions and errors are hard to trigger. """ task = mock.Mock() if loop is ...: # no loop passed, try to get the current one if # its is running as we need a real loop to create # executor jobs to be able to do testing # with a real executor try: loop = asyncio.get_running_loop() except RuntimeError: loop = mock.Mock() loop.create_future.return_value = () if version < HttpVersion(1, 1): closing = True if headers: headers = CIMultiDictProxy(CIMultiDict(headers)) raw_hdrs = tuple( (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items() ) else: headers = CIMultiDictProxy(CIMultiDict()) raw_hdrs = () chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower() message = RawRequestMessage( method, path, version, headers, raw_hdrs, closing, None, False, chunked, URL(path), ) if app is None: app = _create_app_mock() if transport is None: transport = _create_transport(sslcontext) if protocol is None: protocol = mock.Mock() protocol.max_field_size = 8190 protocol.max_line_length = 8190 protocol.max_headers = 128 protocol.transport = transport type(protocol).peername = mock.PropertyMock( return_value=transport.get_extra_info("peername") ) type(protocol).ssl_context = mock.PropertyMock(return_value=sslcontext) if writer is None: writer = mock.Mock() writer.write_headers = mock.AsyncMock(return_value=None) writer.write = mock.AsyncMock(return_value=None) writer.write_eof = mock.AsyncMock(return_value=None) writer.drain = mock.AsyncMock(return_value=None) writer.transport = transport protocol.transport = transport req = Request( message, payload, protocol, writer, task, loop, client_max_size=client_max_size ) match_info = UrlMappingMatchInfo( {} if match_info is None else match_info, mock.Mock() ) match_info.add_app(app) req._match_info = match_info return req ================================================ FILE: aiohttp/tracing.py ================================================ 
from types import SimpleNamespace from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, overload from aiosignal import Signal from multidict import CIMultiDict from yarl import URL from .client_reqrep import ClientResponse from .helpers import frozen_dataclass_decorator if TYPE_CHECKING: from .client import ClientSession __all__ = ( "TraceConfig", "TraceRequestStartParams", "TraceRequestEndParams", "TraceRequestExceptionParams", "TraceConnectionQueuedStartParams", "TraceConnectionQueuedEndParams", "TraceConnectionCreateStartParams", "TraceConnectionCreateEndParams", "TraceConnectionReuseconnParams", "TraceDnsResolveHostStartParams", "TraceDnsResolveHostEndParams", "TraceDnsCacheHitParams", "TraceDnsCacheMissParams", "TraceRequestRedirectParams", "TraceRequestChunkSentParams", "TraceResponseChunkReceivedParams", "TraceRequestHeadersSentParams", ) _T = TypeVar("_T", covariant=True) _ParamT_contra = TypeVar("_ParamT_contra", contravariant=True) _TracingSignal = Signal["ClientSession", _T, _ParamT_contra] class _Factory(Protocol[_T]): def __call__(self, **kwargs: Any) -> _T: ... class TraceConfig(Generic[_T]): """First-class used to trace requests launched via ClientSession objects.""" @overload def __init__(self: "TraceConfig[SimpleNamespace]") -> None: ... @overload def __init__(self, trace_config_ctx_factory: _Factory[_T]) -> None: ... 
def __init__( self, trace_config_ctx_factory: _Factory[Any] = SimpleNamespace ) -> None: self._on_request_start: _TracingSignal[_T, TraceRequestStartParams] = Signal( self ) self._on_request_chunk_sent: _TracingSignal[_T, TraceRequestChunkSentParams] = ( Signal(self) ) self._on_response_chunk_received: _TracingSignal[ _T, TraceResponseChunkReceivedParams ] = Signal(self) self._on_request_end: _TracingSignal[_T, TraceRequestEndParams] = Signal(self) self._on_request_exception: _TracingSignal[_T, TraceRequestExceptionParams] = ( Signal(self) ) self._on_request_redirect: _TracingSignal[_T, TraceRequestRedirectParams] = ( Signal(self) ) self._on_connection_queued_start: _TracingSignal[ _T, TraceConnectionQueuedStartParams ] = Signal(self) self._on_connection_queued_end: _TracingSignal[ _T, TraceConnectionQueuedEndParams ] = Signal(self) self._on_connection_create_start: _TracingSignal[ _T, TraceConnectionCreateStartParams ] = Signal(self) self._on_connection_create_end: _TracingSignal[ _T, TraceConnectionCreateEndParams ] = Signal(self) self._on_connection_reuseconn: _TracingSignal[ _T, TraceConnectionReuseconnParams ] = Signal(self) self._on_dns_resolvehost_start: _TracingSignal[ _T, TraceDnsResolveHostStartParams ] = Signal(self) self._on_dns_resolvehost_end: _TracingSignal[ _T, TraceDnsResolveHostEndParams ] = Signal(self) self._on_dns_cache_hit: _TracingSignal[_T, TraceDnsCacheHitParams] = Signal( self ) self._on_dns_cache_miss: _TracingSignal[_T, TraceDnsCacheMissParams] = Signal( self ) self._on_request_headers_sent: _TracingSignal[ _T, TraceRequestHeadersSentParams ] = Signal(self) self._trace_config_ctx_factory: _Factory[_T] = trace_config_ctx_factory def trace_config_ctx(self, trace_request_ctx: Any = None) -> _T: """Return a new trace_config_ctx instance""" return self._trace_config_ctx_factory(trace_request_ctx=trace_request_ctx) def freeze(self) -> None: self._on_request_start.freeze() self._on_request_chunk_sent.freeze() 
self._on_response_chunk_received.freeze() self._on_request_end.freeze() self._on_request_exception.freeze() self._on_request_redirect.freeze() self._on_connection_queued_start.freeze() self._on_connection_queued_end.freeze() self._on_connection_create_start.freeze() self._on_connection_create_end.freeze() self._on_connection_reuseconn.freeze() self._on_dns_resolvehost_start.freeze() self._on_dns_resolvehost_end.freeze() self._on_dns_cache_hit.freeze() self._on_dns_cache_miss.freeze() self._on_request_headers_sent.freeze() @property def on_request_start(self) -> "_TracingSignal[_T, TraceRequestStartParams]": return self._on_request_start @property def on_request_chunk_sent( self, ) -> "_TracingSignal[_T, TraceRequestChunkSentParams]": return self._on_request_chunk_sent @property def on_response_chunk_received( self, ) -> "_TracingSignal[_T, TraceResponseChunkReceivedParams]": return self._on_response_chunk_received @property def on_request_end(self) -> "_TracingSignal[_T, TraceRequestEndParams]": return self._on_request_end @property def on_request_exception( self, ) -> "_TracingSignal[_T, TraceRequestExceptionParams]": return self._on_request_exception @property def on_request_redirect( self, ) -> "_TracingSignal[_T, TraceRequestRedirectParams]": return self._on_request_redirect @property def on_connection_queued_start( self, ) -> "_TracingSignal[_T, TraceConnectionQueuedStartParams]": return self._on_connection_queued_start @property def on_connection_queued_end( self, ) -> "_TracingSignal[_T, TraceConnectionQueuedEndParams]": return self._on_connection_queued_end @property def on_connection_create_start( self, ) -> "_TracingSignal[_T, TraceConnectionCreateStartParams]": return self._on_connection_create_start @property def on_connection_create_end( self, ) -> "_TracingSignal[_T, TraceConnectionCreateEndParams]": return self._on_connection_create_end @property def on_connection_reuseconn( self, ) -> "_TracingSignal[_T, TraceConnectionReuseconnParams]": return 
self._on_connection_reuseconn @property def on_dns_resolvehost_start( self, ) -> "_TracingSignal[_T, TraceDnsResolveHostStartParams]": return self._on_dns_resolvehost_start @property def on_dns_resolvehost_end( self, ) -> "_TracingSignal[_T, TraceDnsResolveHostEndParams]": return self._on_dns_resolvehost_end @property def on_dns_cache_hit(self) -> "_TracingSignal[_T, TraceDnsCacheHitParams]": return self._on_dns_cache_hit @property def on_dns_cache_miss(self) -> "_TracingSignal[_T, TraceDnsCacheMissParams]": return self._on_dns_cache_miss @property def on_request_headers_sent( self, ) -> "_TracingSignal[_T, TraceRequestHeadersSentParams]": return self._on_request_headers_sent @frozen_dataclass_decorator class TraceRequestStartParams: """Parameters sent by the `on_request_start` signal""" method: str url: URL headers: "CIMultiDict[str]" @frozen_dataclass_decorator class TraceRequestChunkSentParams: """Parameters sent by the `on_request_chunk_sent` signal""" method: str url: URL chunk: bytes @frozen_dataclass_decorator class TraceResponseChunkReceivedParams: """Parameters sent by the `on_response_chunk_received` signal""" method: str url: URL chunk: bytes @frozen_dataclass_decorator class TraceRequestEndParams: """Parameters sent by the `on_request_end` signal""" method: str url: URL headers: "CIMultiDict[str]" response: ClientResponse @frozen_dataclass_decorator class TraceRequestExceptionParams: """Parameters sent by the `on_request_exception` signal""" method: str url: URL headers: "CIMultiDict[str]" exception: BaseException @frozen_dataclass_decorator class TraceRequestRedirectParams: """Parameters sent by the `on_request_redirect` signal""" method: str url: URL headers: "CIMultiDict[str]" response: ClientResponse @frozen_dataclass_decorator class TraceConnectionQueuedStartParams: """Parameters sent by the `on_connection_queued_start` signal""" @frozen_dataclass_decorator class TraceConnectionQueuedEndParams: """Parameters sent by the `on_connection_queued_end` 
signal""" @frozen_dataclass_decorator class TraceConnectionCreateStartParams: """Parameters sent by the `on_connection_create_start` signal""" @frozen_dataclass_decorator class TraceConnectionCreateEndParams: """Parameters sent by the `on_connection_create_end` signal""" @frozen_dataclass_decorator class TraceConnectionReuseconnParams: """Parameters sent by the `on_connection_reuseconn` signal""" @frozen_dataclass_decorator class TraceDnsResolveHostStartParams: """Parameters sent by the `on_dns_resolvehost_start` signal""" host: str @frozen_dataclass_decorator class TraceDnsResolveHostEndParams: """Parameters sent by the `on_dns_resolvehost_end` signal""" host: str @frozen_dataclass_decorator class TraceDnsCacheHitParams: """Parameters sent by the `on_dns_cache_hit` signal""" host: str @frozen_dataclass_decorator class TraceDnsCacheMissParams: """Parameters sent by the `on_dns_cache_miss` signal""" host: str @frozen_dataclass_decorator class TraceRequestHeadersSentParams: """Parameters sent by the `on_request_headers_sent` signal""" method: str url: URL headers: "CIMultiDict[str]" class Trace: """Internal dependency holder class. Used to keep together the main dependencies used at the moment of send a signal. 
""" def __init__( self, session: "ClientSession", trace_config: TraceConfig[object], trace_config_ctx: Any, ) -> None: self._trace_config = trace_config self._trace_config_ctx = trace_config_ctx self._session = session async def send_request_start( self, method: str, url: URL, headers: "CIMultiDict[str]" ) -> None: return await self._trace_config.on_request_start.send( self._session, self._trace_config_ctx, TraceRequestStartParams(method, url, headers), ) async def send_request_chunk_sent( self, method: str, url: URL, chunk: bytes ) -> None: return await self._trace_config.on_request_chunk_sent.send( self._session, self._trace_config_ctx, TraceRequestChunkSentParams(method, url, chunk), ) async def send_response_chunk_received( self, method: str, url: URL, chunk: bytes ) -> None: return await self._trace_config.on_response_chunk_received.send( self._session, self._trace_config_ctx, TraceResponseChunkReceivedParams(method, url, chunk), ) async def send_request_end( self, method: str, url: URL, headers: "CIMultiDict[str]", response: ClientResponse, ) -> None: return await self._trace_config.on_request_end.send( self._session, self._trace_config_ctx, TraceRequestEndParams(method, url, headers, response), ) async def send_request_exception( self, method: str, url: URL, headers: "CIMultiDict[str]", exception: BaseException, ) -> None: return await self._trace_config.on_request_exception.send( self._session, self._trace_config_ctx, TraceRequestExceptionParams(method, url, headers, exception), ) async def send_request_redirect( self, method: str, url: URL, headers: "CIMultiDict[str]", response: ClientResponse, ) -> None: return await self._trace_config._on_request_redirect.send( self._session, self._trace_config_ctx, TraceRequestRedirectParams(method, url, headers, response), ) async def send_connection_queued_start(self) -> None: return await self._trace_config.on_connection_queued_start.send( self._session, self._trace_config_ctx, TraceConnectionQueuedStartParams() ) 
async def send_connection_queued_end(self) -> None: return await self._trace_config.on_connection_queued_end.send( self._session, self._trace_config_ctx, TraceConnectionQueuedEndParams() ) async def send_connection_create_start(self) -> None: return await self._trace_config.on_connection_create_start.send( self._session, self._trace_config_ctx, TraceConnectionCreateStartParams() ) async def send_connection_create_end(self) -> None: return await self._trace_config.on_connection_create_end.send( self._session, self._trace_config_ctx, TraceConnectionCreateEndParams() ) async def send_connection_reuseconn(self) -> None: return await self._trace_config.on_connection_reuseconn.send( self._session, self._trace_config_ctx, TraceConnectionReuseconnParams() ) async def send_dns_resolvehost_start(self, host: str) -> None: return await self._trace_config.on_dns_resolvehost_start.send( self._session, self._trace_config_ctx, TraceDnsResolveHostStartParams(host) ) async def send_dns_resolvehost_end(self, host: str) -> None: return await self._trace_config.on_dns_resolvehost_end.send( self._session, self._trace_config_ctx, TraceDnsResolveHostEndParams(host) ) async def send_dns_cache_hit(self, host: str) -> None: return await self._trace_config.on_dns_cache_hit.send( self._session, self._trace_config_ctx, TraceDnsCacheHitParams(host) ) async def send_dns_cache_miss(self, host: str) -> None: return await self._trace_config.on_dns_cache_miss.send( self._session, self._trace_config_ctx, TraceDnsCacheMissParams(host) ) async def send_request_headers( self, method: str, url: URL, headers: "CIMultiDict[str]" ) -> None: return await self._trace_config._on_request_headers_sent.send( self._session, self._trace_config_ctx, TraceRequestHeadersSentParams(method, url, headers), ) ================================================ FILE: aiohttp/typedefs.py ================================================ import json import os from collections.abc import Awaitable, Callable, Iterable, Mapping from 
http.cookies import BaseCookie, Morsel from typing import TYPE_CHECKING, Any, Protocol from multidict import CIMultiDict, CIMultiDictProxy, istr from yarl import URL DEFAULT_JSON_ENCODER = json.dumps DEFAULT_JSON_DECODER = json.loads if TYPE_CHECKING: from .web import Request, StreamResponse Byteish = bytes | bytearray | memoryview JSONEncoder = Callable[[Any], str] JSONBytesEncoder = Callable[[Any], bytes] JSONDecoder = Callable[[str], Any] LooseHeaders = ( Mapping[str, str] | Mapping[istr, str] | CIMultiDict[str] | CIMultiDictProxy[str] | Iterable[tuple[str | istr, str]] ) RawHeaders = tuple[tuple[bytes, bytes], ...] StrOrURL = str | URL LooseCookiesMappings = Mapping[str, str | BaseCookie[str] | Morsel[Any]] LooseCookiesIterables = Iterable[tuple[str, str | BaseCookie[str] | Morsel[Any]]] LooseCookies = LooseCookiesMappings | LooseCookiesIterables | BaseCookie[str] Handler = Callable[["Request"], Awaitable["StreamResponse"]] class Middleware(Protocol): def __call__( self, request: "Request", handler: Handler ) -> Awaitable["StreamResponse"]: ... 
PathLike = str | os.PathLike[str] ================================================ FILE: aiohttp/web.py ================================================ import asyncio import logging import os import socket import sys import warnings from argparse import ArgumentParser from collections.abc import Awaitable, Callable, Iterable, Iterable as TypingIterable from contextlib import suppress from importlib import import_module from typing import Any, cast from .abc import AbstractAccessLogger from .helpers import AppKey, RequestKey, ResponseKey from .log import access_logger from .typedefs import PathLike from .web_app import Application, CleanupError from .web_exceptions import ( HTTPAccepted, HTTPBadGateway, HTTPBadRequest, HTTPClientError, HTTPConflict, HTTPCreated, HTTPError, HTTPException, HTTPExpectationFailed, HTTPFailedDependency, HTTPForbidden, HTTPFound, HTTPGatewayTimeout, HTTPGone, HTTPInsufficientStorage, HTTPInternalServerError, HTTPLengthRequired, HTTPMethodNotAllowed, HTTPMisdirectedRequest, HTTPMove, HTTPMovedPermanently, HTTPMultipleChoices, HTTPNetworkAuthenticationRequired, HTTPNoContent, HTTPNonAuthoritativeInformation, HTTPNotAcceptable, HTTPNotExtended, HTTPNotFound, HTTPNotImplemented, HTTPNotModified, HTTPOk, HTTPPartialContent, HTTPPaymentRequired, HTTPPermanentRedirect, HTTPPreconditionFailed, HTTPPreconditionRequired, HTTPProxyAuthenticationRequired, HTTPRedirection, HTTPRequestEntityTooLarge, HTTPRequestHeaderFieldsTooLarge, HTTPRequestRangeNotSatisfiable, HTTPRequestTimeout, HTTPRequestURITooLong, HTTPResetContent, HTTPSeeOther, HTTPServerError, HTTPServiceUnavailable, HTTPSuccessful, HTTPTemporaryRedirect, HTTPTooManyRequests, HTTPUnauthorized, HTTPUnavailableForLegalReasons, HTTPUnprocessableEntity, HTTPUnsupportedMediaType, HTTPUpgradeRequired, HTTPUseProxy, HTTPVariantAlsoNegotiates, HTTPVersionNotSupported, NotAppKeyWarning, ) from .web_fileresponse import FileResponse from .web_log import AccessLogger from .web_middlewares import 
middleware, normalize_path_middleware from .web_protocol import PayloadAccessError, RequestHandler, RequestPayloadError from .web_request import BaseRequest, FileField, Request from .web_response import ( ContentCoding, Response, StreamResponse, json_bytes_response, json_response, ) from .web_routedef import ( AbstractRouteDef, RouteDef, RouteTableDef, StaticDef, delete, get, head, options, patch, post, put, route, static, view, ) from .web_runner import ( AppRunner, BaseRunner, BaseSite, GracefulExit, NamedPipeSite, ServerRunner, SockSite, TCPSite, UnixSite, ) from .web_server import Server from .web_urldispatcher import ( AbstractResource, AbstractRoute, DynamicResource, PlainResource, PrefixedSubAppResource, Resource, ResourceRoute, StaticResource, UrlDispatcher, UrlMappingMatchInfo, View, ) from .web_ws import WebSocketReady, WebSocketResponse, WSMsgType __all__ = ( # web_app "AppKey", "Application", "CleanupError", # web_exceptions "NotAppKeyWarning", "HTTPAccepted", "HTTPBadGateway", "HTTPBadRequest", "HTTPClientError", "HTTPConflict", "HTTPCreated", "HTTPError", "HTTPException", "HTTPExpectationFailed", "HTTPFailedDependency", "HTTPForbidden", "HTTPFound", "HTTPGatewayTimeout", "HTTPGone", "HTTPInsufficientStorage", "HTTPInternalServerError", "HTTPLengthRequired", "HTTPMethodNotAllowed", "HTTPMisdirectedRequest", "HTTPMove", "HTTPMovedPermanently", "HTTPMultipleChoices", "HTTPNetworkAuthenticationRequired", "HTTPNoContent", "HTTPNonAuthoritativeInformation", "HTTPNotAcceptable", "HTTPNotExtended", "HTTPNotFound", "HTTPNotImplemented", "HTTPNotModified", "HTTPOk", "HTTPPartialContent", "HTTPPaymentRequired", "HTTPPermanentRedirect", "HTTPPreconditionFailed", "HTTPPreconditionRequired", "HTTPProxyAuthenticationRequired", "HTTPRedirection", "HTTPRequestEntityTooLarge", "HTTPRequestHeaderFieldsTooLarge", "HTTPRequestRangeNotSatisfiable", "HTTPRequestTimeout", "HTTPRequestURITooLong", "HTTPResetContent", "HTTPSeeOther", "HTTPServerError", 
"HTTPServiceUnavailable", "HTTPSuccessful", "HTTPTemporaryRedirect", "HTTPTooManyRequests", "HTTPUnauthorized", "HTTPUnavailableForLegalReasons", "HTTPUnprocessableEntity", "HTTPUnsupportedMediaType", "HTTPUpgradeRequired", "HTTPUseProxy", "HTTPVariantAlsoNegotiates", "HTTPVersionNotSupported", # web_fileresponse "FileResponse", # web_middlewares "middleware", "normalize_path_middleware", # web_protocol "PayloadAccessError", "RequestHandler", "RequestPayloadError", # web_request "BaseRequest", "FileField", "Request", "RequestKey", # web_response "ContentCoding", "Response", "StreamResponse", "json_bytes_response", "json_response", "ResponseKey", # web_routedef "AbstractRouteDef", "RouteDef", "RouteTableDef", "StaticDef", "delete", "get", "head", "options", "patch", "post", "put", "route", "static", "view", # web_runner "AppRunner", "BaseRunner", "BaseSite", "GracefulExit", "ServerRunner", "SockSite", "TCPSite", "UnixSite", "NamedPipeSite", # web_server "Server", # web_urldispatcher "AbstractResource", "AbstractRoute", "DynamicResource", "PlainResource", "PrefixedSubAppResource", "Resource", "ResourceRoute", "StaticResource", "UrlDispatcher", "UrlMappingMatchInfo", "View", # web_ws "WebSocketReady", "WebSocketResponse", "WSMsgType", # web "run_app", ) try: from ssl import SSLContext except ImportError: # pragma: no cover SSLContext = object # type: ignore[misc,assignment] # Only display warning when using -Wdefault, -We, -X dev or similar. 
warnings.filterwarnings("ignore", category=NotAppKeyWarning, append=True) HostSequence = TypingIterable[str] async def _run_app( app: Application | Awaitable[Application], *, host: str | HostSequence | None = None, port: int | None = None, path: PathLike | TypingIterable[PathLike] | None = None, sock: socket.socket | TypingIterable[socket.socket] | None = None, ssl_context: SSLContext | None = None, print: Callable[..., None] | None = print, backlog: int = 128, reuse_address: bool | None = None, reuse_port: bool | None = None, **kwargs: Any, # TODO(PY311): Use Unpack ) -> None: # An internal function to actually do all dirty job for application running if asyncio.iscoroutine(app): app = await app app = cast(Application, app) runner = AppRunner(app, **kwargs) await runner.setup() sites: list[BaseSite] = [] try: if host is not None: if isinstance(host, str): sites.append( TCPSite( runner, host, port, ssl_context=ssl_context, backlog=backlog, reuse_address=reuse_address, reuse_port=reuse_port, ) ) else: for h in host: sites.append( TCPSite( runner, h, port, ssl_context=ssl_context, backlog=backlog, reuse_address=reuse_address, reuse_port=reuse_port, ) ) elif path is None and sock is None or port is not None: sites.append( TCPSite( runner, port=port, ssl_context=ssl_context, backlog=backlog, reuse_address=reuse_address, reuse_port=reuse_port, ) ) if path is not None: if isinstance(path, (str, os.PathLike)): sites.append( UnixSite( runner, path, ssl_context=ssl_context, backlog=backlog, ) ) else: for p in path: sites.append( UnixSite( runner, p, ssl_context=ssl_context, backlog=backlog, ) ) if sock is not None: if not isinstance(sock, Iterable): sites.append( SockSite( runner, sock, ssl_context=ssl_context, backlog=backlog, ) ) else: for s in sock: sites.append( SockSite( runner, s, ssl_context=ssl_context, backlog=backlog, ) ) for site in sites: await site.start() if print: # pragma: no branch names = sorted(str(s.name) for s in runner.sites) print( "======== Running 
on {} ========\n" "(Press CTRL+C to quit)".format(", ".join(names)) ) # sleep forever by 1 hour intervals, while True: await asyncio.sleep(3600) finally: await runner.cleanup() def _cancel_tasks( to_cancel: set["asyncio.Task[Any]"], loop: asyncio.AbstractEventLoop ) -> None: if not to_cancel: return for task in to_cancel: task.cancel() loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) for task in to_cancel: if task.cancelled(): continue if task.exception() is not None: loop.call_exception_handler( { "message": "unhandled exception during asyncio.run() shutdown", "exception": task.exception(), "task": task, } ) def run_app( app: Application | Awaitable[Application], *, debug: bool = False, host: str | HostSequence | None = None, port: int | None = None, path: PathLike | TypingIterable[PathLike] | None = None, sock: socket.socket | TypingIterable[socket.socket] | None = None, shutdown_timeout: float = 60.0, keepalive_timeout: float = 75.0, ssl_context: SSLContext | None = None, print: Callable[..., None] | None = print, backlog: int = 128, access_log_class: type[AbstractAccessLogger] = AccessLogger, access_log_format: str = AccessLogger.LOG_FORMAT, access_log: logging.Logger | None = access_logger, handle_signals: bool = True, reuse_address: bool | None = None, reuse_port: bool | None = None, handler_cancellation: bool = False, loop: asyncio.AbstractEventLoop | None = None, **kwargs: Any, ) -> None: """Run an app locally""" if loop is None: loop = asyncio.new_event_loop() loop.set_debug(debug) # Configure if and only if in debugging mode and using the default logger if loop.get_debug() and access_log and access_log.name == "aiohttp.access": if access_log.level == logging.NOTSET: access_log.setLevel(logging.DEBUG) if not access_log.hasHandlers(): access_log.addHandler(logging.StreamHandler()) main_task = loop.create_task( _run_app( app, host=host, port=port, path=path, sock=sock, shutdown_timeout=shutdown_timeout, 
keepalive_timeout=keepalive_timeout, ssl_context=ssl_context, print=print, backlog=backlog, access_log_class=access_log_class, access_log_format=access_log_format, access_log=access_log, handle_signals=handle_signals, reuse_address=reuse_address, reuse_port=reuse_port, handler_cancellation=handler_cancellation, **kwargs, ) ) try: asyncio.set_event_loop(loop) loop.run_until_complete(main_task) except (GracefulExit, KeyboardInterrupt): pass finally: try: main_task.cancel() with suppress(asyncio.CancelledError): loop.run_until_complete(main_task) finally: _cancel_tasks(asyncio.all_tasks(loop), loop) loop.run_until_complete(loop.shutdown_asyncgens()) loop.close() asyncio.set_event_loop(None) def main(argv: list[str]) -> None: arg_parser = ArgumentParser( description="aiohttp.web Application server", prog="aiohttp.web" ) arg_parser.add_argument( "entry_func", help=( "Callable returning the `aiohttp.web.Application` instance to " "run. Should be specified in the 'module:function' syntax." ), metavar="entry-func", ) arg_parser.add_argument( "-H", "--hostname", help="TCP/IP hostname to serve on (default: localhost)", default=None, ) arg_parser.add_argument( "-P", "--port", help="TCP/IP port to serve on (default: %(default)r)", type=int, default=8080, ) arg_parser.add_argument( "-U", "--path", help="Unix file system path to serve on. 
Can be combined with hostname " "to serve on both Unix and TCP.", ) args, extra_argv = arg_parser.parse_known_args(argv) # Import logic mod_str, _, func_str = args.entry_func.partition(":") if not func_str or not mod_str: arg_parser.error("'entry-func' not in 'module:function' syntax") if mod_str.startswith("."): arg_parser.error("relative module names not supported") try: module = import_module(mod_str) except ImportError as ex: arg_parser.error(f"unable to import {mod_str}: {ex}") try: func = getattr(module, func_str) except AttributeError: arg_parser.error(f"module {mod_str!r} has no attribute {func_str!r}") # Compatibility logic if args.path is not None and not hasattr(socket, "AF_UNIX"): arg_parser.error( "file system paths not supported by your operating environment" ) logging.basicConfig(level=logging.DEBUG) if args.path and args.hostname is None: host = port = None else: host = args.hostname or "localhost" port = args.port app = func(extra_argv) run_app(app, host=host, port=port, path=args.path) arg_parser.exit(message="Stopped\n") if __name__ == "__main__": # pragma: no branch main(sys.argv[1:]) # pragma: no cover ================================================ FILE: aiohttp/web_app.py ================================================ import asyncio import logging import warnings from collections.abc import ( AsyncIterator, Awaitable, Callable, Iterable, Iterator, Mapping, MutableMapping, Sequence, ) from contextlib import AbstractAsyncContextManager, asynccontextmanager from functools import lru_cache, partial, update_wrapper from typing import Any, TypeVar, cast, final, overload from aiosignal import Signal from frozenlist import FrozenList from . 
import hdrs from .helpers import AppKey from .log import web_logger from .typedefs import Handler, Middleware from .web_exceptions import NotAppKeyWarning from .web_middlewares import _fix_request_current_app from .web_request import Request from .web_response import StreamResponse from .web_routedef import AbstractRouteDef from .web_urldispatcher import ( AbstractResource, AbstractRoute, Domain, MaskDomain, MatchedSubAppResource, PrefixedSubAppResource, SystemRoute, UrlDispatcher, ) __all__ = ("Application", "CleanupError") _AppSignal = Signal["Application"] _RespPrepareSignal = Signal[Request, StreamResponse] _Middlewares = FrozenList[Middleware] _MiddlewaresHandlers = Sequence[Middleware] _Subapps = list["Application"] _T = TypeVar("_T") _U = TypeVar("_U") _Resource = TypeVar("_Resource", bound=AbstractResource) def _build_middlewares( handler: Handler, apps: tuple["Application", ...] ) -> Callable[[Request], Awaitable[StreamResponse]]: """Apply middlewares to handler.""" # The slice is to reverse the order of the apps # so they are applied in the order they were added for app in apps[::-1]: assert app.pre_frozen, "middleware handlers are not ready" for m in app._middlewares_handlers: handler = update_wrapper(partial(m, handler=handler), handler) return handler _cached_build_middleware = lru_cache(maxsize=1024)(_build_middlewares) @final class Application(MutableMapping[str | AppKey[Any], Any]): __slots__ = ( "logger", "_router", "_loop", "_handler_args", "_middlewares", "_middlewares_handlers", "_run_middlewares", "_state", "_frozen", "_pre_frozen", "_subapps", "_on_response_prepare", "_on_startup", "_on_shutdown", "_on_cleanup", "_client_max_size", "_cleanup_ctx", ) def __init__( self, *, logger: logging.Logger = web_logger, middlewares: Iterable[Middleware] = (), handler_args: Mapping[str, Any] | None = None, client_max_size: int = 1024**2, debug: Any = ..., # mypy doesn't support ellipsis ) -> None: if debug is not ...: warnings.warn( "debug argument is 
@overload  # type: ignore[override]
def __setitem__(self, key: AppKey[_T], value: _T) -> None: ...


@overload
def __setitem__(self, key: str, value: Any) -> None: ...


def __setitem__(self, key: str | AppKey[_T], value: Any) -> None:
    """Store *value* under *key* in the application state.

    Raises ``RuntimeError`` (via ``_check_frozen()``) once the application
    is frozen.  A key that is not an ``AppKey`` instance is still accepted
    but triggers a ``NotAppKeyWarning``.
    """
    self._check_frozen()
    typed_key = isinstance(key, AppKey)
    if not typed_key:
        message = (
            "It is recommended to use web.AppKey instances for keys.\n"
            "https://docs.aiohttp.org/en/stable/web_advanced.html"
            "#application-s-config"
        )
        # stacklevel=2 attributes the warning to the code doing ``app[key] = ...``.
        warnings.warn(message, category=NotAppKeyWarning, stacklevel=2)
    self._state[key] = value
def add_subapp(self, prefix: str, subapp: "Application") -> PrefixedSubAppResource:
    """Mount *subapp* under the URL *prefix* and return the new resource.

    Trailing slashes are stripped from *prefix* before use.

    Raises:
        TypeError: if *prefix* is not a ``str``.
        ValueError: if *prefix* is empty once trailing slashes are removed.
    """
    if not isinstance(prefix, str):
        raise TypeError("Prefix must be str")
    stripped = prefix.rstrip("/")
    if stripped == "":
        raise ValueError("Prefix cannot be empty")
    # The resource is created lazily by _add_subapp(), hence the factory.
    make_resource = partial(PrefixedSubAppResource, stripped, subapp)
    return self._add_subapp(make_resource, subapp)
async def cleanup(self) -> None:
    """Causes on_cleanup signal

    Should be called after shutdown()
    """
    if not self.on_cleanup.frozen:
        # The signal is not frozen, so it cannot be sent.  If an exception
        # occurred in startup, still make sure cleanup contexts are completed.
        await self._cleanup_ctx._on_cleanup(self)
        return
    await self.on_cleanup.send(self)
class CleanupContext(FrozenList[_CleanupContextCallable]):
    """Frozen list of callbacks that set up and tear down app resources.

    Each callback receives the application and yields either an async
    context manager or an async iterator.  Contexts are entered in list
    order on startup and exited in reverse order on cleanup.
    """

    def __init__(self) -> None:
        super().__init__()
        # Contexts that were entered successfully, in entry order.
        self._exits: list[AbstractAsyncContextManager[None]] = []

    async def _on_startup(self, app: Application) -> None:
        """Enter the context produced by every registered callback."""
        for callback in self:
            context = callback(app)
            if not isinstance(context, AbstractAsyncContextManager):
                # Not a context manager: wrap the callback itself, which
                # means it is invoked a second time by the wrapper.
                context = asynccontextmanager(callback)(app)  # type: ignore[arg-type]
            await context.__aenter__()
            self._exits.append(context)

    async def _on_cleanup(self, app: Application) -> None:
        """Exit all entered contexts in reverse order, collecting failures.

        A single failure is re-raised as is; several failures are wrapped
        in one ``CleanupError``.
        """
        failures = []
        for context in reversed(self._exits):
            try:
                await context.__aexit__(None, None, None)
            except (Exception, asyncio.CancelledError) as exc:
                # Keep going so that the remaining contexts are still exited.
                failures.append(exc)
        if not failures:
            return
        if len(failures) == 1:
            raise failures[0]
        raise CleanupError("Multiple errors on cleanup stage", failures)
class HTTPException(CookieMixin, Exception):
    """Base class of the HTTP status exceptions defined in this module.

    An instance carries a reason phrase, an optional text body and a
    case-insensitive multidict of headers.
    """

    # Subclasses are expected to override ``status_code``.
    status_code = -1
    empty_body = False
    default_reason = ""  # Initialized at the end of the module

    def __init__(
        self,
        *,
        headers: LooseHeaders | None = None,
        reason: str | None = None,
        text: str | None = None,
        content_type: str | None = None,
    ) -> None:
        """Build the exception.

        *reason* defaults to ``default_reason`` and must not contain CR or
        LF (``ValueError`` otherwise).  When *text* is omitted and the
        status allows a body, it defaults to ``"<status_code>: <reason>"``.
        A Content-Type header is set from *content_type*, or to
        ``text/plain`` when there is text and no such header was supplied.
        """
        if reason is None:
            reason = self.default_reason
        elif any(ch in reason for ch in "\r\n"):
            raise ValueError("Reason cannot contain \\r or \\n")

        if text is not None:
            if self.empty_body:
                warnings.warn(
                    f"text argument is deprecated for HTTP status {self.status_code} "
                    "since 4.0 and scheduled for removal in 5.0 (#3462),"
                    "the response should be provided without a body",
                    DeprecationWarning,
                    stacklevel=2,
                )
        elif not self.empty_body:
            text = f"{self.status_code}: {reason}"

        real_headers = CIMultiDict() if headers is None else CIMultiDict(headers)

        if content_type is None:
            if text and hdrs.CONTENT_TYPE not in real_headers:
                real_headers[hdrs.CONTENT_TYPE] = "text/plain"
        else:
            if not text:
                warnings.warn(
                    "content_type without text is deprecated "
                    "since 4.0 and scheduled for removal in 5.0 "
                    "(#3462)",
                    DeprecationWarning,
                    stacklevel=2,
                )
            real_headers[hdrs.CONTENT_TYPE] = content_type

        self._reason = reason
        self._text = text
        self._headers = real_headers
        self.args = ()

    def __bool__(self) -> bool:
        # Always truthy, whatever the container-like mixin state is.
        return True

    @property
    def status(self) -> int:
        """The numeric HTTP status (``status_code`` of the class)."""
        return self.status_code

    @property
    def reason(self) -> str:
        """The reason phrase chosen at construction time."""
        return self._reason

    @property
    def text(self) -> str | None:
        """The body text, or ``None`` when there is none."""
        return self._text

    @property
    def headers(self) -> "CIMultiDict[str]":
        """The headers multidict of this exception."""
        return self._headers

    def __str__(self) -> str:
        return self.reason

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self.reason}>"

    # Use the plain object protocol for copying/pickling rather than the
    # one inherited from the exception base class.
    __reduce__ = object.__reduce__

    def __getnewargs__(self) -> tuple[Any, ...]:
        return self.args
class HTTPMethodNotAllowed(HTTPClientError):
    """``405 Method Not Allowed`` with the mandatory ``Allow`` header.

    Args:
        method: the request method that was rejected.
        allowed_methods: the methods the resource supports; any iterable,
            including a one-shot iterator or generator.
        headers, reason, text, content_type: forwarded unchanged to the
            base exception.

    The ``Allow`` header holds the methods sorted and comma-joined, and
    ``allowed_methods`` exposes them as a set.
    """

    status_code = 405

    def __init__(
        self,
        method: str,
        allowed_methods: Iterable[str],
        *,
        headers: LooseHeaders | None = None,
        reason: str | None = None,
        text: str | None = None,
        content_type: str | None = None,
    ) -> None:
        # Materialise the iterable exactly once: it is consumed both for
        # the header and for the set, and an iterator would be exhausted
        # after the first pass, leaving ``allowed_methods`` empty.
        methods = tuple(allowed_methods)
        allow = ",".join(sorted(methods))
        super().__init__(
            headers=headers, reason=reason, text=text, content_type=content_type
        )
        self.headers["Allow"] = allow
        self._allowed: set[str] = set(methods)
        self._method = method

    @property
    def allowed_methods(self) -> set[str]:
        """The set of methods advertised in the ``Allow`` header."""
        return self._allowed

    @property
    def method(self) -> str:
        """The request method that was not allowed."""
        return self._method
class HTTPUnavailableForLegalReasons(HTTPClientError):
    """``451 Unavailable For Legal Reasons``.

    When a truthy *link* is given it is converted to a ``URL``, exposed
    through the ``link`` property and advertised in a ``Link`` header with
    ``rel="blocked-by"``.  Otherwise ``link`` is ``None`` and no header is
    added.
    """

    status_code = 451

    def __init__(
        self,
        link: StrOrURL | None,
        *,
        headers: LooseHeaders | None = None,
        reason: str | None = None,
        text: str | None = None,
        content_type: str | None = None,
    ) -> None:
        super().__init__(
            headers=headers, reason=reason, text=text, content_type=content_type
        )
        self._link = URL(link) if link else None
        if self._link is not None:
            self.headers["Link"] = f'<{str(self._link)}>; rel="blocked-by"'

    @property
    def link(self) -> URL | None:
        """The blocking entity's URL, or ``None`` when none was supplied."""
        return self._link
def _initialize_default_reason() -> None:
    """Fill ``default_reason`` of every HTTP exception class in this module.

    Each ``HTTPException`` subclass found in the module globals that has a
    non-negative ``status_code`` gets the standard phrase of that status
    from ``http.HTTPStatus``.  Codes unknown to ``HTTPStatus`` are left
    untouched.
    """
    for candidate in globals().values():
        if not isinstance(candidate, type):
            continue
        if not issubclass(candidate, HTTPException):
            continue
        if candidate.status_code < 0:
            # Abstract base classes keep the placeholder status code.
            continue
        try:
            known_status = HTTPStatus(candidate.status_code)
        except ValueError:
            # Not a status the standard library knows about.
            continue
        candidate.default_reason = known_status.phrase
import hdrs from .abc import AbstractStreamWriter from .helpers import ETAG_ANY, ETag, must_be_empty_body from .typedefs import LooseHeaders, PathLike from .web_exceptions import ( HTTPForbidden, HTTPNotFound, HTTPNotModified, HTTPPartialContent, HTTPPreconditionFailed, HTTPRequestRangeNotSatisfiable, ) from .web_response import StreamResponse __all__ = ("FileResponse",) if TYPE_CHECKING: from .web_request import BaseRequest _T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]] NOSENDFILE: Final[bool] = bool(os.environ.get("AIOHTTP_NOSENDFILE")) CONTENT_TYPES: Final[MimeTypes] = MimeTypes() # File extension to IANA encodings map that will be checked in the order defined. ENCODING_EXTENSIONS = MappingProxyType( {ext: CONTENT_TYPES.encodings_map[ext] for ext in (".br", ".gz")} ) FALLBACK_CONTENT_TYPE = "application/octet-stream" # Provide additional MIME type/extension pairs to be recognized. # https://en.wikipedia.org/wiki/List_of_archive_formats#Compression_only ADDITIONAL_CONTENT_TYPES = MappingProxyType( { "application/gzip": ".gz", "application/x-brotli": ".br", "application/x-bzip2": ".bz2", "application/x-compress": ".Z", "application/x-xz": ".xz", } ) class _FileResponseResult(Enum): """The result of the file response.""" SEND_FILE = auto() # Ie a regular file to send NOT_ACCEPTABLE = auto() # Ie a socket, or non-regular file PRE_CONDITION_FAILED = auto() # Ie If-Match or If-None-Match failed NOT_MODIFIED = auto() # 304 Not Modified # Add custom pairs and clear the encodings map so guess_type ignores them. 
CONTENT_TYPES.encodings_map.clear() for content_type, extension in ADDITIONAL_CONTENT_TYPES.items(): CONTENT_TYPES.add_type(content_type, extension) _CLOSE_FUTURES: set[asyncio.Future[None]] = set() class FileResponse(StreamResponse): """A response object can be used to send files.""" def __init__( self, path: PathLike, chunk_size: int = 256 * 1024, status: int = 200, reason: str | None = None, headers: LooseHeaders | None = None, ) -> None: super().__init__(status=status, reason=reason, headers=headers) self._path = pathlib.Path(path) self._chunk_size = chunk_size def _seek_and_read(self, fobj: IO[Any], offset: int, chunk_size: int) -> bytes: fobj.seek(offset) return fobj.read(chunk_size) # type: ignore[no-any-return] async def _sendfile_fallback( self, writer: AbstractStreamWriter, fobj: IO[Any], offset: int, count: int ) -> AbstractStreamWriter: # To keep memory usage low,fobj is transferred in chunks # controlled by the constructor's chunk_size argument. chunk_size = self._chunk_size loop = asyncio.get_event_loop() chunk = await loop.run_in_executor( None, self._seek_and_read, fobj, offset, min(chunk_size, count) ) while chunk: await writer.write(chunk) count = count - len(chunk) if count <= 0: break chunk = await loop.run_in_executor(None, fobj.read, min(chunk_size, count)) await writer.drain() return writer async def _sendfile( self, request: "BaseRequest", fobj: IO[Any], offset: int, count: int ) -> AbstractStreamWriter: writer = await super().prepare(request) assert writer is not None if NOSENDFILE or self.compression: return await self._sendfile_fallback(writer, fobj, offset, count) loop = request._loop transport = request.transport assert transport is not None try: await loop.sendfile(transport, fobj, offset, count) except NotImplementedError: return await self._sendfile_fallback(writer, fobj, offset, count) await super().write_eof() return writer @staticmethod def _etag_match(etag_value: str, etags: tuple[ETag, ...], *, weak: bool) -> bool: if len(etags) 
async def _not_modified(
    self, request: "BaseRequest", etag_value: str, last_modified: float
) -> AbstractStreamWriter | None:
    """Answer *request* with ``304 Not Modified``.

    Sets the status to ``HTTPNotModified.status_code``, stores
    *etag_value* and *last_modified* on the response and delegates to the
    parent ``prepare()``, whose result is returned.
    """
    self.set_status(HTTPNotModified.status_code)
    # NOTE(review): ``_length_check`` is consumed by the parent class, which
    # is not visible here; per the comment below it is presumably what drops
    # a user-provided Content-Length -- confirm in StreamResponse.
    self._length_check = False
    self.etag = etag_value
    self.last_modified = last_modified
    # Delete any Content-Length headers provided by user. HTTP 304
    # should always have empty response body
    return await super().prepare(request)
""" file_path, st, file_encoding = self._get_file_path_stat_encoding( accept_encoding ) if not file_path: return _FileResponseResult.NOT_ACCEPTABLE, None, st, None etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}" # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2 if (ifmatch := request.if_match) is not None and not self._etag_match( etag_value, ifmatch, weak=False ): return _FileResponseResult.PRE_CONDITION_FAILED, None, st, file_encoding if ( (unmodsince := request.if_unmodified_since) is not None and ifmatch is None and st.st_mtime > unmodsince.timestamp() ): return _FileResponseResult.PRE_CONDITION_FAILED, None, st, file_encoding # https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2 if (ifnonematch := request.if_none_match) is not None and self._etag_match( etag_value, ifnonematch, weak=True ): return _FileResponseResult.NOT_MODIFIED, None, st, file_encoding if ( (modsince := request.if_modified_since) is not None and ifnonematch is None and st.st_mtime <= modsince.timestamp() ): return _FileResponseResult.NOT_MODIFIED, None, st, file_encoding fobj = file_path.open("rb") with suppress(OSError): # fstat() may not be available on all platforms # Once we open the file, we want the fstat() to ensure # the file has not changed between the first stat() # and the open(). st = os.stat(fobj.fileno()) return _FileResponseResult.SEND_FILE, fobj, st, file_encoding def _get_file_path_stat_encoding( self, accept_encoding: str ) -> tuple[pathlib.Path | None, os.stat_result, str | None]: file_path = self._path for file_extension, file_encoding in ENCODING_EXTENSIONS.items(): if file_encoding not in accept_encoding: continue compressed_path = file_path.with_suffix(file_path.suffix + file_extension) with suppress(OSError): # Do not follow symlinks and ignore any non-regular files. 
st = compressed_path.lstat() if S_ISREG(st.st_mode): return compressed_path, st, file_encoding # Fallback to the uncompressed file st = file_path.stat() return file_path if S_ISREG(st.st_mode) else None, st, None async def prepare(self, request: "BaseRequest") -> AbstractStreamWriter | None: loop = asyncio.get_running_loop() # Encoding comparisons should be case-insensitive # https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1 accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() try: response_result, fobj, st, file_encoding = await loop.run_in_executor( None, self._make_response, request, accept_encoding ) except PermissionError: self.set_status(HTTPForbidden.status_code) return await super().prepare(request) except OSError: # Most likely to be FileNotFoundError or OSError for circular # symlinks in python >= 3.13, so respond with 404. self.set_status(HTTPNotFound.status_code) return await super().prepare(request) # Forbid special files like sockets, pipes, devices, etc. if response_result is _FileResponseResult.NOT_ACCEPTABLE: self.set_status(HTTPForbidden.status_code) return await super().prepare(request) if response_result is _FileResponseResult.PRE_CONDITION_FAILED: return await self._precondition_failed(request) if response_result is _FileResponseResult.NOT_MODIFIED: etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}" last_modified = st.st_mtime return await self._not_modified(request, etag_value, last_modified) assert fobj is not None try: return await self._prepare_open_file(request, fobj, st, file_encoding) finally: # We do not await here because we do not want to wait # for the executor to finish before returning the response # so the connection can begin servicing another request # as soon as possible. close_future = loop.run_in_executor(None, fobj.close) # Hold a strong reference to the future to prevent it from being # garbage collected before it completes. 
_CLOSE_FUTURES.add(close_future) close_future.add_done_callback(_CLOSE_FUTURES.remove) async def _prepare_open_file( self, request: "BaseRequest", fobj: io.BufferedReader, st: os.stat_result, file_encoding: str | None, ) -> AbstractStreamWriter | None: status = self._status file_size: int = st.st_size file_mtime: float = st.st_mtime count: int = file_size start: int | None = None if (ifrange := request.if_range) is None or file_mtime <= ifrange.timestamp(): # If-Range header check: # condition = cached date >= last modification date # return 206 if True else 200. # if False: # Range header would not be processed, return 200 # if True but Range header missing # return 200 try: rng = request.http_range start = rng.start end: int | None = rng.stop except ValueError: # https://tools.ietf.org/html/rfc7233: # A server generating a 416 (Range Not Satisfiable) response to # a byte-range request SHOULD send a Content-Range header field # with an unsatisfied-range value. # The complete-length in a 416 response indicates the current # length of the selected representation. # # Will do the same below. 
Many servers ignore this and do not # send a Content-Range header with HTTP 416 self._headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}" self.set_status(HTTPRequestRangeNotSatisfiable.status_code) return await super().prepare(request) # If a range request has been made, convert start, end slice # notation into file pointer offset and count if start is not None: if start < 0 and end is None: # return tail of file start += file_size if start < 0: # if Range:bytes=-1000 in request header but file size # is only 200, there would be trouble without this start = 0 count = file_size - start else: # rfc7233:If the last-byte-pos value is # absent, or if the value is greater than or equal to # the current length of the representation data, # the byte range is interpreted as the remainder # of the representation (i.e., the server replaces the # value of last-byte-pos with a value that is one less than # the current length of the selected representation). count = ( min(end if end is not None else file_size, file_size) - start ) if start >= file_size: # HTTP 416 should be returned in this case. # # According to https://tools.ietf.org/html/rfc7233: # If a valid byte-range-set includes at least one # byte-range-spec with a first-byte-pos that is less than # the current length of the representation, or at least one # suffix-byte-range-spec with a non-zero suffix-length, # then the byte-range-set is satisfiable. Otherwise, the # byte-range-set is unsatisfiable. self._headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}" self.set_status(HTTPRequestRangeNotSatisfiable.status_code) return await super().prepare(request) status = HTTPPartialContent.status_code # Even though you are sending the whole file, you should still # return a HTTP 206 for a Range request. self.set_status(status) # If the Content-Type header is not already set, guess it based on the # extension of the request path. The encoding returned by guess_type # can be ignored since the map was cleared above. 
class AccessLogger(AbstractAccessLogger):
    """Helper object to log access.

    Usage:
        log = logging.getLogger("spam")
        log_format = "%a %{User-Agent}i"
        access_logger = AccessLogger(log, log_format)
        access_logger.log(request, response, time)

    Format:
        %%  The percent sign
        %a  Remote IP-address (IP-address of proxy if using reverse proxy)
        %t  Time when the request was started to process
        %P  The process ID of the child that serviced the request
        %r  First line of request
        %s  Response status code
        %b  Size of response in bytes, including HTTP headers
        %T  Time taken to serve the request, in seconds
        %Tf Time taken to serve the request, in seconds with floating fraction
            in .06f format
        %D  Time taken to serve the request, in microseconds
        %{FOO}i  request.headers['FOO']
        %{FOO}o  response.headers['FOO']
        %{FOO}e  os.environ['FOO']

    Header and environment atoms render "-" when the key is missing.
    """

    # Maps a format atom to the key used in the ``extra`` dict passed to the
    # underlying logger.  Every atom accepted by FORMAT_RE needs an entry here
    # and a matching ``_format_<atom>`` method.
    LOG_FORMAT_MAP = {
        "a": "remote_address",
        "t": "request_start_time",
        "P": "process_id",
        "r": "first_request_line",
        "s": "response_status",
        "b": "response_size",
        "T": "request_time",
        "Tf": "request_time_frac",
        "D": "request_time_micro",
        "i": "request_header",
        "o": "response_header",
        # Documented in the class docstring and accepted by FORMAT_RE; without
        # this entry "%{FOO}e" failed with KeyError while compiling the format.
        "e": "environ",
    }

    LOG_FORMAT = '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"'
    FORMAT_RE = re.compile(r"%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbOD]|Tf?)")
    CLEANUP_RE = re.compile(r"(%[^s])")
    # Compiled formats are shared between instances, keyed by the raw format.
    _FORMAT_CACHE: dict[str, tuple[str, list[KeyMethod]]] = {}

    _cached_tz: ClassVar[datetime.timezone | None] = None
    _cached_tz_expires: ClassVar[float] = 0.0

    def __init__(self, logger: logging.Logger, log_format: str = LOG_FORMAT) -> None:
        """Initialise the logger.

        logger is a logger object to be used for logging.
        log_format is a string with apache compatible log format description.

        """
        super().__init__(logger, log_format=log_format)

        _compiled_format = AccessLogger._FORMAT_CACHE.get(log_format)
        if _compiled_format is None:
            _compiled_format = self.compile_format(log_format)
            AccessLogger._FORMAT_CACHE[log_format] = _compiled_format

        self._log_format, self._methods = _compiled_format

    def compile_format(self, log_format: str) -> tuple[str, list[KeyMethod]]:
        """Translate log_format into form usable by modulo formatting

        All known atoms will be replaced with %s
        Also methods for formatting of those atoms will be added to
        _methods in appropriate order

        For example we have log_format = "%a %t"
        This format will be translated to "%s %s"
        Also contents of _methods will be
        [self._format_a, self._format_t]
        These method will be called and results will be passed
        to translated string format.

        Each _format_* method receive 'args' which is list of arguments
        given to self.log

        Exceptions are _format_e, _format_i and _format_o methods which
        also receive key name (by functools.partial)

        """
        # list of (key, method) tuples, we don't use an OrderedDict as users
        # can repeat the same key more than once
        methods: list[KeyMethod] = []

        for atom in self.FORMAT_RE.findall(log_format):
            if atom[1] == "":
                # Plain atom such as %a or %Tf.
                format_key1 = self.LOG_FORMAT_MAP[atom[0]]
                m = getattr(AccessLogger, f"_format_{atom[0]}")
                key_method = KeyMethod(format_key1, m)
            else:
                # Keyed atom such as %{User-Agent}i: bind the key name.
                format_key2 = (self.LOG_FORMAT_MAP[atom[2]], atom[1])
                m = getattr(AccessLogger, f"_format_{atom[2]}")
                key_method = KeyMethod(format_key2, functools.partial(m, atom[1]))

            methods.append(key_method)

        log_format = self.FORMAT_RE.sub(r"%s", log_format)
        # Escape any remaining "%x" so the final modulo formatting only sees %s.
        log_format = self.CLEANUP_RE.sub(r"%\1", log_format)
        return log_format, methods

    @staticmethod
    def _format_e(
        key: str, request: BaseRequest, response: StreamResponse, time: float
    ) -> str:
        """Return os.environ[key], or "-" when the variable is not set."""
        return os.environ.get(key, "-")

    @staticmethod
    def _format_i(
        key: str, request: BaseRequest, response: StreamResponse, time: float
    ) -> str:
        # suboptimal, make istr(key) once
        return request.headers.get(key, "-")

    @staticmethod
    def _format_o(
        key: str, request: BaseRequest, response: StreamResponse, time: float
    ) -> str:
        # suboptimal, make istr(key) once
        return response.headers.get(key, "-")

    @staticmethod
    def _format_a(request: BaseRequest, response: StreamResponse, time: float) -> str:
        ip = request.remote
        return ip if ip is not None else "-"

    @classmethod
    def _get_local_time(cls) -> datetime.datetime:
        """Return the current local time using a cached UTC offset."""
        if cls._cached_tz is None or time_mod.time() >= cls._cached_tz_expires:
            gmtoff = time_mod.localtime().tm_gmtoff
            cls._cached_tz = tz = datetime.timezone(datetime.timedelta(seconds=gmtoff))
            now = datetime.datetime.now(tz)
            # Expire at every 30 mins, as any DST change should occur at 0/30 mins past.
            d = now + datetime.timedelta(minutes=30)
            d = d.replace(minute=30 if d.minute >= 30 else 0, second=0, microsecond=0)
            cls._cached_tz_expires = d.timestamp()
            return now
        return datetime.datetime.now(cls._cached_tz)

    @staticmethod
    def _format_t(request: BaseRequest, response: StreamResponse, time: float) -> str:
        now = AccessLogger._get_local_time()
        # ``time`` is the elapsed handling time, so subtracting it gives the start.
        start_time = now - datetime.timedelta(seconds=time)
        return start_time.strftime("[%d/%b/%Y:%H:%M:%S %z]")

    @staticmethod
    def _format_P(request: BaseRequest, response: StreamResponse, time: float) -> str:
        return "<%s>" % os.getpid()

    @staticmethod
    def _format_r(request: BaseRequest, response: StreamResponse, time: float) -> str:
        return f"{request.method} {request.path_qs} HTTP/{request.version.major}.{request.version.minor}"

    @staticmethod
    def _format_s(request: BaseRequest, response: StreamResponse, time: float) -> int:
        return response.status

    @staticmethod
    def _format_b(request: BaseRequest, response: StreamResponse, time: float) -> int:
        return response.body_length

    @staticmethod
    def _format_T(request: BaseRequest, response: StreamResponse, time: float) -> str:
        return str(round(time))

    @staticmethod
    def _format_Tf(request: BaseRequest, response: StreamResponse, time: float) -> str:
        return "%06f" % time

    @staticmethod
    def _format_D(request: BaseRequest, response: StreamResponse, time: float) -> str:
        return str(round(time * 1000000))

    def _format_line(
        self, request: BaseRequest, response: StreamResponse, time: float
    ) -> Iterable[tuple[str, Callable[[BaseRequest, StreamResponse, float], str]]]:
        """Evaluate every compiled atom, returning (key, rendered value) pairs."""
        return [(key, method(request, response, time)) for key, method in self._methods]

    @property
    def enabled(self) -> bool:
        """Check if logger is enabled."""
        # Avoid formatting the log line if it will not be emitted.
        return self.logger.isEnabledFor(logging.INFO)

    def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
        """Format one access-log line and emit it at INFO level.

        The rendered values are also passed through ``extra``: plain atoms as
        ``extra[name]`` and keyed atoms as ``extra[name][key]``.  Any failure
        while formatting is logged instead of being raised to the caller.
        """
        try:
            fmt_info = self._format_line(request, response, time)

            values = []
            extra = {}
            for key, value in fmt_info:
                values.append(value)

                if key.__class__ is str:
                    extra[key] = value
                else:
                    k1, k2 = key  # type: ignore[misc]
                    dct = extra.get(k1, {})  # type: ignore[var-annotated,has-type]
                    dct[k2] = value  # type: ignore[index,has-type]
                    extra[k1] = dct  # type: ignore[has-type,assignment]

            self.logger.info(self._log_format % tuple(values), extra=extra)
        except Exception:
            self.logger.exception("Error in logging")
def normalize_path_middleware(
    *,
    append_slash: bool = True,
    remove_slash: bool = False,
    merge_slashes: bool = True,
    redirect_class: type[HTTPMove] = HTTPPermanentRedirect,
) -> Middleware:
    """Create a middleware that redirects to a normalized request path.

    Normalization is only attempted when the request resolved to a
    ``SystemRoute``.  Candidate paths are tried in this order and the first
    one that resolves triggers a redirect (``redirect_class``) to it, with
    the original query string re-attached:

      1) consecutive slashes merged (``merge_slashes``)
      2) trailing slash appended (``append_slash``) or removed
         (``remove_slash``)
      3) both of the above combined

    ``append_slash`` and ``remove_slash`` are mutually exclusive; enabling
    both raises an assertion error when the factory is called.

    If no candidate resolves, the request is passed to the handler unchanged.
    """
    assert not (append_slash and remove_slash), "Cannot both remove and append slash"

    async def impl(request: Request, handler: Handler) -> StreamResponse:
        # Requests that matched a real route need no normalization.
        if not isinstance(request.match_info.route, SystemRoute):
            return await handler(request)

        # Split off the query string; ``query`` keeps its leading "?" (or is
        # empty) so it can be appended verbatim to the redirect target.
        raw_path, separator, raw_query = request.raw_path.partition("?")
        query = separator + raw_query

        candidates = []
        if merge_slashes:
            candidates.append(re.sub("//+", "/", raw_path))
        if append_slash and not request.path.endswith("/"):
            candidates.append(raw_path + "/")
        if remove_slash and request.path.endswith("/"):
            candidates.append(raw_path[:-1])
        if merge_slashes and append_slash:
            candidates.append(re.sub("//+", "/", raw_path + "/"))
        if merge_slashes and remove_slash and raw_path.endswith("/"):
            candidates.append(re.sub("//+", "/", raw_path)[:-1])

        for candidate in candidates:
            candidate = re.sub("^//+", "/", candidate)  # SECURITY: GHSA-v6wp-4m6f-gcjg
            resolves, request = await _check_request_resolves(request, candidate)
            if resolves:
                raise redirect_class(request.raw_path + query)

        return await handler(request)

    return impl
class AccessLoggerWrapper(AbstractAsyncAccessLogger):
    """Adapter that exposes an AbstractAccessLogger through the async logger API."""

    __slots__ = ("access_logger", "_loop")

    def __init__(
        self, access_logger: AbstractAccessLogger, loop: asyncio.AbstractEventLoop
    ) -> None:
        """Remember the wrapped synchronous logger and the loop used as clock."""
        self.access_logger, self._loop = access_logger, loop
        super().__init__()

    async def log(
        self, request: BaseRequest, response: StreamResponse, request_start: float
    ) -> None:
        """Forward the entry to the wrapped logger.

        ``request_start`` is a loop timestamp; the wrapped logger receives the
        elapsed time since then instead.
        """
        elapsed = self._loop.time() - request_start
        self.access_logger.log(request, response, elapsed)

    @property
    def enabled(self) -> bool:
        """Check if logger is enabled."""
        wrapped = self.access_logger
        return wrapped.enabled
def __init__(
    self,
    manager: "Server[_Request]",
    *,
    loop: asyncio.AbstractEventLoop,
    # Default should be high enough that it's likely longer than a reverse proxy.
    keepalive_timeout: float = 3630,
    tcp_keepalive: bool = True,
    logger: Logger = server_logger,
    access_log_class: _AnyAbstractAccessLogger = AccessLogger,
    access_log: Logger | None = access_logger,
    access_log_format: str = AccessLogger.LOG_FORMAT,
    max_line_size: int = 8190,
    max_headers: int = 128,
    max_field_size: int = 8190,
    lingering_time: float = 10.0,
    read_bufsize: int = 2**16,
    auto_decompress: bool = True,
    timeout_ceil_threshold: float = 5,
):
    """Set up per-connection state, the request parser and access logging."""
    super().__init__(loop)

    # Number of requests processed so far on this connection.
    self._request_count = 0
    self._keepalive = False
    self._current_request: _Request | None = None

    # Everything taken from the owning server.
    self._manager: Server[_Request] | None = manager
    self._request_handler: _RequestHandler[_Request] | None = (
        manager.request_handler
    )
    self._request_factory: _RequestFactory[_Request] | None = (
        manager.request_factory
    )

    # Parser limits, also exposed as public attributes.
    self.max_line_size = max_line_size
    self.max_headers = max_headers
    self.max_field_size = max_field_size

    # Keep-alive and lingering-close configuration.
    self._tcp_keepalive = tcp_keepalive
    # placeholder to be replaced on keepalive timeout setup
    self._next_keepalive_close_time = 0.0
    self._keepalive_handle: asyncio.Handle | None = None
    self._keepalive_timeout = keepalive_timeout
    self._lingering_time = float(lingering_time)

    # Parsed-but-unhandled messages and raw bytes awaiting a payload parser.
    self._messages: deque[_MsgType] = deque()
    self._message_tail = b""
    self._data_received_cb: Callable[[], None] | None = None

    # Futures/tasks used to coordinate the handler loop.
    self._waiter: asyncio.Future[None] | None = None
    self._handler_waiter: asyncio.Future[None] | None = None
    self._task_handler: asyncio.Task[None] | None = None

    self._upgrade = False
    self._payload_parser: Any = None
    self._request_parser: HttpRequestParser | None = HttpRequestParser(
        self,
        loop,
        read_bufsize,
        max_line_size=max_line_size,
        max_field_size=max_field_size,
        max_headers=max_headers,
        payload_exception=RequestPayloadError,
        auto_decompress=auto_decompress,
    )

    # Fall back to 5 when the supplied threshold cannot be converted to float.
    self._timeout_ceil_threshold: float = 5
    with suppress(TypeError, ValueError):
        self._timeout_ceil_threshold = float(timeout_ceil_threshold)

    self.logger = logger
    self.access_log = access_log
    if not access_log:
        self.access_logger: AbstractAsyncAccessLogger | None = None
        self._logging_enabled = False
    else:
        if issubclass(access_log_class, AbstractAsyncAccessLogger):
            # Async logger classes are instantiated without arguments.
            self.access_logger = access_log_class()
        else:
            # Synchronous loggers are adapted to the async interface.
            sync_logger = access_log_class(access_log, access_log_format)
            self.access_logger = AccessLoggerWrapper(sync_logger, self._loop)
        self._logging_enabled = self.access_logger.enabled

    self._close = False
    self._force_close = False
    self._request_in_progress = False
    self._cache: dict[str, Any] = {}
def connection_made(self, transport: asyncio.BaseTransport) -> None:
    """Register the new connection and start the request-processing task."""
    super().connection_made(transport)

    stream_transport = cast(asyncio.Transport, transport)
    if self._tcp_keepalive:
        tcp_keepalive(stream_transport)

    manager = self._manager
    assert manager is not None
    manager.connection_made(self, stream_transport)

    event_loop = self._loop
    if sys.version_info >= (3, 12):
        # eager_start runs start() immediately, before the task is stored below.
        handler_task = asyncio.Task(self.start(), loop=event_loop, eager_start=True)
    else:
        handler_task = event_loop.create_task(self.start())
    self._task_handler = handler_task
handler_cancellation = self._manager.handler_cancellation self.force_close() super().connection_lost(exc) self._manager = None self._request_factory = None self._request_handler = None self._request_parser = None if self._keepalive_handle is not None: self._keepalive_handle.cancel() if self._current_request is not None: if exc is None: exc = ConnectionResetError("Connection lost") self._current_request._cancel(exc) if handler_cancellation and self._task_handler is not None: self._task_handler.cancel() self._task_handler = None if self._payload_parser is not None: self._payload_parser.feed_eof() self._payload_parser = None def set_parser( self, parser: Any, data_received_cb: Callable[[], None] | None = None ) -> None: # Actual type is WebReader assert self._payload_parser is None self._payload_parser = parser self._data_received_cb = data_received_cb if self._message_tail: self._payload_parser.feed_data(self._message_tail) self._message_tail = b"" def eof_received(self) -> None: pass def data_received(self, data: bytes) -> None: if self._force_close or self._close: return # parse http messages messages: Sequence[_MsgType] if self._payload_parser is None and not self._upgrade: assert self._request_parser is not None try: messages, upgraded, tail = self._request_parser.feed_data(data) except HttpProcessingError as exc: messages = [ (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD) ] upgraded = False tail = b"" for msg, payload in messages or (): self._request_count += 1 self._messages.append((msg, payload)) waiter = self._waiter if messages and waiter is not None and not waiter.done(): # don't set result twice waiter.set_result(None) self._upgrade = upgraded if upgraded and tail: self._message_tail = tail # no parser, just store elif self._payload_parser is None and self._upgrade and data: self._message_tail += data # feed payload elif data: if self._data_received_cb is not None: self._data_received_cb() eof, tail = 
async def _handle_request(
    self,
    request: _Request,
    start_time: float | None,
    request_handler: Callable[[_Request], Awaitable[StreamResponse]],
) -> tuple[StreamResponse, bool]:
    """Run one request through ``request_handler`` and send its response.

    Handler failures are turned into responses: an ``HTTPException`` becomes
    a ``Response`` built from the exception, a timeout becomes the 504 error
    response and any other ``Exception`` the 500 error response.
    ``asyncio.CancelledError`` is re-raised untouched.

    Returns the ``(response, reset)`` pair produced by ``finish_response``.
    """
    # Flag read by shutdown() to decide whether it has to wait for us.
    self._request_in_progress = True
    try:
        try:
            # Expose the request only while the handler is running.
            self._current_request = request
            resp = await request_handler(request)
        finally:
            self._current_request = None
    except HTTPException as exc:
        # Raised HTTP exceptions are converted into an ordinary response,
        # carrying over the cookies set on the exception.
        resp = Response(
            status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
        )
        resp._cookies = exc._cookies
        # finish_response is called inside each except block on purpose: its
        # docstring requires the exception context for the access logger.
        resp, reset = await self.finish_response(request, resp, start_time)
    except asyncio.CancelledError:
        raise
    except asyncio.TimeoutError as exc:
        self.log_debug("Request handler timed out.", exc_info=exc)
        resp = self.handle_error(request, 504)
        resp, reset = await self.finish_response(request, resp, start_time)
    except Exception as exc:
        resp = self.handle_error(request, 500, exc)
        resp, reset = await self.finish_response(request, resp, start_time)
    else:
        resp, reset = await self.finish_response(request, resp, start_time)
    finally:
        self._request_in_progress = False
        # shutdown() only creates this future while a request is in
        # progress; resolving it tells shutdown() the handler finished.
        if self._handler_waiter is not None:
            self._handler_waiter.set_result(None)

    return resp, reset
start = loop.time() if self._logging_enabled else None manager.requests_count += 1 writer = StreamWriter(self, loop) if not isinstance(message, _ErrInfo): request_handler = self._request_handler else: # make request_factory work request_handler = self._make_error_handler(message) message = ERROR # Important don't hold a reference to the current task # as on traceback it will prevent the task from being # collected and will cause a memory leak. request = self._request_factory( message, payload, self, writer, self._task_handler or asyncio.current_task(loop), # type: ignore[arg-type] ) try: # a new task is used for copy context vars (#3406) coro = self._handle_request(request, start, request_handler) if sys.version_info >= (3, 12): task = asyncio.Task(coro, loop=loop, eager_start=True) else: task = loop.create_task(coro) try: resp, reset = await task except ConnectionError: self.log_debug("Ignored premature client disconnection") break # Drop the processed task from asyncio.Task.all_tasks() early del task # https://github.com/python/mypy/issues/14309 if reset: # type: ignore[possibly-undefined] self.log_debug("Ignored premature client disconnection 2") break # notify server about keep-alive self._keepalive = bool(resp.keep_alive) # check payload if not payload.is_eof(): lingering_time = self._lingering_time # Could be force closed while awaiting above tasks. 
async def finish_response(
    self, request: BaseRequest, resp: StreamResponse, start_time: float | None
) -> tuple[StreamResponse, bool]:
    """Prepare the response and write_eof, then log access.

    This has to
    be called within the context of any exception so the access logger
    can get exception information. Returns True if the client disconnects
    prematurely.

    The return value is a ``(response, reset)`` tuple: ``response`` is the
    response that was actually sent (a 500 response replaces ``resp`` when it
    has no ``prepare`` attribute) and ``reset`` is True when a
    ``ConnectionError`` occurred while sending it.
    """
    request._finish()
    if self._request_parser is not None:
        # Leave upgraded mode and feed back any bytes stored meanwhile.
        self._request_parser.set_upgraded(False)
        self._upgrade = False
        if self._message_tail:
            self._request_parser.feed_data(self._message_tail)
            self._message_tail = b""
    try:
        prepare_meth = resp.prepare
    except AttributeError:
        # The handler returned something that is not a response object;
        # log it (inside the except block, so the traceback is attached)
        # and answer with a generic 500 instead.
        if resp is None:
            self.log_exception("Missing return statement on request handler")  # type: ignore[unreachable]
        else:
            self.log_exception(
                f"Web-handler should return a response instance, got {resp!r}"
            )
        exc = HTTPInternalServerError()
        resp = Response(
            status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
        )
        prepare_meth = resp.prepare
    try:
        await prepare_meth(request)
        await resp.write_eof()
    except ConnectionError:
        # Client went away while the response was being sent: still log the
        # access, and report the reset to the caller.
        await self.log_access(request, resp, start_time)
        return resp, True

    await self.log_access(request, resp, start_time)
    return resp, False
self.logger.debug( "Error handling request from %s", request.remote, exc_info=exc ) else: self.log_exception( "Error handling request from %s", request.remote, exc_info=exc ) # some data already got sent, connection is broken if request.writer.output_size > 0: raise ConnectionError( "Response is sent already, cannot send another response " "with the error message" ) ct = "text/plain" if status == HTTPStatus.INTERNAL_SERVER_ERROR: title = f"{HTTPStatus.INTERNAL_SERVER_ERROR.value} {HTTPStatus.INTERNAL_SERVER_ERROR.phrase}" msg = HTTPStatus.INTERNAL_SERVER_ERROR.description tb = None if self._loop.get_debug(): with suppress(Exception): tb = traceback.format_exc() if "text/html" in request.headers.get("Accept", ""): if tb: tb = html_escape(tb) msg = f"

Traceback:

\n
{tb}
" message = ( "" f"{title}" f"\n

{title}

" f"\n{msg}\n\n" ) ct = "text/html" else: if tb: msg = tb message = title + "\n\n" + msg resp = Response(status=status, text=message, content_type=ct) resp.force_close() return resp def _make_error_handler( self, err_info: _ErrInfo ) -> Callable[[BaseRequest], Awaitable[StreamResponse]]: async def handler(request: BaseRequest) -> StreamResponse: return self.handle_error( request, err_info.status, err_info.exc, err_info.message ) return handler ================================================ FILE: aiohttp/web_request.py ================================================ import asyncio import datetime import io import re import socket import string import sys import tempfile import types from collections.abc import Iterator, Mapping, MutableMapping from re import Pattern from types import MappingProxyType from typing import TYPE_CHECKING, Any, Final, Optional, TypeVar, cast, overload from urllib.parse import parse_qsl from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy from yarl import URL from . 
import hdrs from ._cookie_helpers import parse_cookie_header from .abc import AbstractStreamWriter from .helpers import ( _SENTINEL, ETAG_ANY, LIST_QUOTED_ETAG_RE, ChainMapProxy, ETag, HeadersMixin, RequestKey, frozen_dataclass_decorator, is_expected_content_type, parse_http_date, reify, sentinel, set_exception, ) from .http_parser import RawRequestMessage from .http_writer import HttpVersion from .multipart import BodyPartReader, MultipartReader from .streams import EmptyStreamReader, StreamReader from .typedefs import ( DEFAULT_JSON_DECODER, JSONDecoder, LooseHeaders, RawHeaders, StrOrURL, ) from .web_exceptions import ( HTTPBadRequest, HTTPRequestEntityTooLarge, HTTPUnsupportedMediaType, ) from .web_response import StreamResponse if sys.version_info >= (3, 11): from typing import Self else: Self = Any __all__ = ("BaseRequest", "FileField", "Request") if TYPE_CHECKING: from .web_app import Application from .web_protocol import RequestHandler from .web_urldispatcher import UrlMappingMatchInfo _T = TypeVar("_T") @frozen_dataclass_decorator class FileField: name: str filename: str file: io.BufferedReader content_type: str headers: CIMultiDictProxy[str] _TCHAR: Final[str] = string.digits + string.ascii_letters + r"!#$%&'*+.^_`|~-" # '-' at the end to prevent interpretation as range in a char class _TOKEN: Final[str] = rf"[{_TCHAR}]+" _QDTEXT: Final[str] = r"[{}]".format( r"".join(chr(c) for c in (0x09, 0x20, 0x21) + tuple(range(0x23, 0x7F))) ) # qdtext includes 0x5C to escape 0x5D ('\]') # qdtext excludes obs-text (because obsoleted, and encoding not specified) _QUOTED_PAIR: Final[str] = r"\\[\t !-~]" _QUOTED_STRING: Final[str] = rf'"(?:{_QUOTED_PAIR}|{_QDTEXT})*"' # This does not have a ReDOS/performance concern as long as it used with re.match(). _FORWARDED_PAIR: Final[str] = rf"({_TOKEN})=({_TOKEN}|{_QUOTED_STRING})(:\d{{1,4}})?" 
_QUOTED_PAIR_REPLACE_RE: Final[Pattern[str]] = re.compile(r"\\([\t !-~])") # same pattern as _QUOTED_PAIR but contains a capture group _FORWARDED_PAIR_RE: Final[Pattern[str]] = re.compile(_FORWARDED_PAIR) ############################################################ # HTTP Request ############################################################ class BaseRequest(MutableMapping[str | RequestKey[Any], Any], HeadersMixin): POST_METHODS = { hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT, hdrs.METH_TRACE, hdrs.METH_DELETE, } _post: MultiDictProxy[str | bytes | FileField] | None = None _read_bytes: bytes | None = None def __init__( self, message: RawRequestMessage, payload: StreamReader, protocol: "RequestHandler[Self]", payload_writer: AbstractStreamWriter, task: "asyncio.Task[None]", loop: asyncio.AbstractEventLoop, *, client_max_size: int = 1024**2, state: dict[RequestKey[Any] | str, Any] | None = None, scheme: str | None = None, host: str | None = None, remote: str | None = None, ) -> None: self._message = message self._protocol = protocol self._payload_writer = payload_writer self._payload = payload self._headers: CIMultiDictProxy[str] = message.headers self._method = message.method self._version = message.version self._cache: dict[str, Any] = {} url = message.url if url.absolute: if scheme is not None: url = url.with_scheme(scheme) if host is not None: url = url.with_host(host) # absolute URL is given, # override auto-calculating url, host, and scheme # all other properties should be good self._cache["url"] = url self._cache["host"] = url.host self._cache["scheme"] = url.scheme self._rel_url = url.relative() else: self._rel_url = url if scheme is not None: self._cache["scheme"] = scheme if host is not None: self._cache["host"] = host self._state = {} if state is None else state self._task = task self._client_max_size = client_max_size self._loop = loop self._transport_sslcontext = protocol.ssl_context self._transport_peername = protocol.peername if remote is not None: 
def clone(
    self,
    *,
    method: str | _SENTINEL = sentinel,
    rel_url: StrOrURL | _SENTINEL = sentinel,
    headers: LooseHeaders | _SENTINEL = sentinel,
    scheme: str | _SENTINEL = sentinel,
    host: str | _SENTINEL = sentinel,
    remote: str | _SENTINEL = sentinel,
    client_max_size: int | _SENTINEL = sentinel,
) -> "BaseRequest":
    """Create a new request of the same class with some attributes replaced.

    Every argument left at its sentinel default is taken from this request.
    The new instance shares the payload, protocol, payload writer, task and
    loop of this one and receives a shallow copy of the state mapping.

    Raises RuntimeError when a non-empty body has already been read.
    """
    if self._read_bytes:
        raise RuntimeError("Cannot clone request after reading its content")

    # Fields of the raw message that have to be swapped out.
    message_fields: dict[str, Any] = {}
    if method is not sentinel:
        message_fields["method"] = method
    if rel_url is not sentinel:
        replacement_url: URL = URL(rel_url)
        message_fields.update(url=replacement_url, path=str(replacement_url))
    if headers is not sentinel:
        # Build a fresh multidict so the clone does not share the caller's mapping.
        copied_headers = CIMultiDictProxy(CIMultiDict(headers))
        message_fields["headers"] = copied_headers
        message_fields["raw_headers"] = tuple(
            (name.encode("utf-8"), value.encode("utf-8"))
            for name, value in copied_headers.items()
        )
    cloned_message = self._message._replace(**message_fields)

    # Only explicitly supplied overrides are forwarded to the constructor.
    overrides: dict[str, str] = {
        key: given
        for key, given in (("scheme", scheme), ("host", host), ("remote", remote))
        if given is not sentinel
    }
    max_size = (
        self._client_max_size if client_max_size is sentinel else client_max_size
    )

    return self.__class__(
        cloned_message,
        self._payload,
        self._protocol,  # type: ignore[arg-type]
        self._payload_writer,
        self._task,
        self._loop,
        client_max_size=max_size,
        state=self._state.copy(),
        **overrides,
    )
rel_url(self) -> URL: return self._rel_url # MutableMapping API @overload # type: ignore[override] def __getitem__(self, key: RequestKey[_T]) -> _T: ... @overload def __getitem__(self, key: str) -> Any: ... def __getitem__(self, key: str | RequestKey[_T]) -> Any: return self._state[key] @overload # type: ignore[override] def __setitem__(self, key: RequestKey[_T], value: _T) -> None: ... @overload def __setitem__(self, key: str, value: Any) -> None: ... def __setitem__(self, key: str | RequestKey[_T], value: Any) -> None: self._state[key] = value def __delitem__(self, key: str | RequestKey[_T]) -> None: del self._state[key] def __len__(self) -> int: return len(self._state) def __iter__(self) -> Iterator[str | RequestKey[Any]]: return iter(self._state) ######## @reify def secure(self) -> bool: """A bool indicating if the request is handled with SSL.""" return self.scheme == "https" @reify def forwarded(self) -> tuple[Mapping[str, str], ...]: """A tuple containing all parsed Forwarded header(s). Makes an effort to parse Forwarded headers as specified by RFC 7239: - It adds one (immutable) dictionary per Forwarded 'field-value', ie per proxy. The element corresponds to the data in the Forwarded field-value added by the first proxy encountered by the client. Each subsequent item corresponds to those added by later proxies. - It checks that every value has valid syntax in general as specified in section 4: either a 'token' or a 'quoted-string'. - It un-escapes found escape sequences. - It does NOT validate 'by' and 'for' contents as specified in section 6. - It does NOT validate 'host' contents (Host ABNF). - It does NOT validate 'proto' contents for valid URI scheme names. 
@reify
def scheme(self) -> str:
    """A string representing the scheme of the request.

    Scheme is resolved in this order:

    - overridden value by .clone(scheme=new_scheme) call.
    - type of connection to peer: HTTPS if socket is SSL, HTTP otherwise.

    The value is either 'http' or 'https'.
    """
    return "https" if self._transport_sslcontext else "http"
@reify
def remote(self) -> str | None:
    """Remote IP of client initiated HTTP request.

    The IP is resolved in this order:

    - overridden value by .clone(remote=new_remote) call.
    - peername of opened socket

    Returns None when the protocol reported no peername.
    """
    peer = self._transport_peername
    if peer is None:
        return None
    # A list/tuple peername carries the address as its first item;
    # any other value is stringified as a whole.
    address = peer[0] if isinstance(peer, (list, tuple)) else peer
    return str(address)
E.g., id=10 """ return self._rel_url.query_string @reify def headers(self) -> CIMultiDictProxy[str]: """A case-insensitive multidict proxy with all headers.""" return self._headers @reify def raw_headers(self) -> RawHeaders: """A sequence of pairs for all headers.""" return self._message.raw_headers @reify def if_modified_since(self) -> datetime.datetime | None: """The value of If-Modified-Since HTTP header, or None. This header is represented as a `datetime` object. """ return parse_http_date(self.headers.get(hdrs.IF_MODIFIED_SINCE)) @reify def if_unmodified_since(self) -> datetime.datetime | None: """The value of If-Unmodified-Since HTTP header, or None. This header is represented as a `datetime` object. """ return parse_http_date(self.headers.get(hdrs.IF_UNMODIFIED_SINCE)) @staticmethod def _etag_values(etag_header: str) -> Iterator[ETag]: """Extract `ETag` objects from raw header.""" if etag_header == ETAG_ANY: yield ETag( is_weak=False, value=ETAG_ANY, ) else: for match in LIST_QUOTED_ETAG_RE.finditer(etag_header): is_weak, value, garbage = match.group(2, 3, 4) # Any symbol captured by 4th group means # that the following sequence is invalid. if garbage: break yield ETag( is_weak=bool(is_weak), value=value, ) @classmethod def _if_match_or_none_impl( cls, header_value: str | None ) -> tuple[ETag, ...] | None: if not header_value: return None return tuple(cls._etag_values(header_value)) @reify def if_match(self) -> tuple[ETag, ...] | None: """The value of If-Match HTTP header, or None. This header is represented as a `tuple` of `ETag` objects. """ return self._if_match_or_none_impl(self.headers.get(hdrs.IF_MATCH)) @reify def if_none_match(self) -> tuple[ETag, ...] | None: """The value of If-None-Match HTTP header, or None. This header is represented as a `tuple` of `ETag` objects. """ return self._if_match_or_none_impl(self.headers.get(hdrs.IF_NONE_MATCH)) @reify def if_range(self) -> datetime.datetime | None: """The value of If-Range HTTP header, or None. 
@reify
def http_range(self) -> "slice[int, int, int]":
    """The content of Range HTTP header, as a slice with step 1.

    Without a Range header both slice bounds are None. ``bytes=a-b`` gives
    ``slice(a, b + 1, 1)``, ``bytes=a-`` gives ``slice(a, None, 1)`` and
    ``bytes=-n`` gives ``slice(-n, None, 1)``.

    Raises ValueError when the header does not have the ``bytes=<a>-<b>``
    form, when neither bound is given, or when the start lies after the end.
    """
    header = self._headers.get(hdrs.RANGE)
    if header is None:
        return slice(None, None, 1)

    found = re.findall(r"^bytes=(\d*)-(\d*)$", header, re.ASCII)
    if not found:
        # pattern was not found in header
        raise ValueError("range not in acceptable format")

    first_text, last_text = found[0]
    first = int(first_text) if first_text else None
    last = int(last_text) if last_text else None

    if first is None:
        if last is None:
            # No valid range supplied
            raise ValueError("No start or end of range specified")
        # end with no start is to return tail of content
        return slice(-last, None, 1)

    if last is None:
        return slice(first, None, 1)

    # end is inclusive in range header, exclusive for slice
    stop = last + 1
    if first >= stop:
        raise ValueError("start cannot be after end")
    return slice(first, stop, 1)
async def read(self) -> bytes:
    """Read request body if present.

    Returns a bytes object with the full request content. The result is
    stored on the request, so later calls return it without touching the
    payload again.

    Raises HTTPRequestEntityTooLarge as soon as the accumulated body
    exceeds a non-zero ``client_max_size``.
    """
    cached = self._read_bytes
    if cached is not None:
        return cached

    limit = self._client_max_size
    buffer = bytearray()
    receiving = True
    while receiving:
        piece = await self._payload.readany()
        buffer += piece
        # A falsy limit (0) disables the size check.
        if limit and len(buffer) > limit:
            raise HTTPRequestEntityTooLarge(
                max_size=limit, actual_size=len(buffer)
            )
        # An empty chunk ends the loop.
        receiving = bool(piece)

    self._read_bytes = bytes(buffer)
    return self._read_bytes
multipart = await self.multipart() max_size = self._client_max_size size = 0 while (field := await multipart.next()) is not None: field_ct = field.headers.get(hdrs.CONTENT_TYPE) if isinstance(field, BodyPartReader): if field.name is None: raise ValueError("Multipart field missing name.") # Note that according to RFC 7578, the Content-Type header # is optional, even for files, so we can't assume it's # present. # https://tools.ietf.org/html/rfc7578#section-4.4 if field.filename: # store file in temp file tmp = await self._loop.run_in_executor( None, tempfile.TemporaryFile ) while chunk := await field.read_chunk(size=2**18): async for decoded_chunk in field.decode_iter(chunk): await self._loop.run_in_executor( None, tmp.write, decoded_chunk ) size += len(decoded_chunk) if 0 < max_size < size: await self._loop.run_in_executor(None, tmp.close) raise HTTPRequestEntityTooLarge( max_size=max_size, actual_size=size ) await self._loop.run_in_executor(None, tmp.seek, 0) if field_ct is None: field_ct = "application/octet-stream" ff = FileField( field.name, field.filename, cast(io.BufferedReader, tmp), field_ct, field.headers, ) out.add(field.name, ff) else: # deal with ordinary data raw_data = bytearray() while chunk := await field.read_chunk(): size += len(chunk) if 0 < max_size < size: raise HTTPRequestEntityTooLarge( max_size=max_size, actual_size=size ) raw_data.extend(chunk) value = bytearray() # form-data doesn't support compression, so don't need to check size again. 
def _finish(self) -> None:
    """Close the files held by parsed multipart form fields.

    Does nothing when the POST data was never parsed (``_post`` is None)
    or when the request content type is not ``multipart/form-data``.
    """
    if self._post is None or self.content_type != "multipart/form-data":
        return

    # NOTE: Release file descriptors for the
    # NOTE: `tempfile.Temporaryfile`-created `_io.BufferedRandom`
    # NOTE: instances of files sent within multipart request body
    # NOTE: via HTTP POST request.
    # Only the values are needed; the field names play no role here.
    for field_value in self._post.values():
        if isinstance(field_value, FileField):
            field_value.file.close()
multidict import CIMultiDict, istr from . import hdrs, payload from .abc import AbstractStreamWriter from .compression_utils import ZLibCompressor from .helpers import ( ETAG_ANY, QUOTED_ETAG_RE, CookieMixin, ETag, HeadersMixin, ResponseKey, must_be_empty_body, parse_http_date, populate_with_cookies, rfc822_formatted_time, sentinel, should_remove_content_length, validate_etag_value, ) from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11 from .payload import Payload from .typedefs import JSONBytesEncoder, JSONEncoder, LooseHeaders REASON_PHRASES = {http_status.value: http_status.phrase for http_status in HTTPStatus} LARGE_BODY_SIZE = 1024**2 __all__ = ( "ContentCoding", "StreamResponse", "Response", "json_response", "json_bytes_response", ) if TYPE_CHECKING: from .web_request import BaseRequest _T = TypeVar("_T") # TODO(py311): Convert to StrEnum for wider use class ContentCoding(enum.Enum): # The content codings that we have support for. # # Additional registered codings are listed at: # https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding deflate = "deflate" gzip = "gzip" identity = "identity" CONTENT_CODINGS = {coding.value: coding for coding in ContentCoding} ############################################################ # HTTP Response classes ############################################################ class StreamResponse( MutableMapping[str | ResponseKey[Any], Any], HeadersMixin, CookieMixin ): _body: None | bytes | bytearray | Payload _length_check = True _body = None _keep_alive: bool | None = None _chunked: bool = False _compression: bool = False _compression_strategy: int | None = None _compression_force: ContentCoding | None = None _req: Optional["BaseRequest"] = None _payload_writer: AbstractStreamWriter | None = None _eof_sent: bool = False _must_be_empty_body: bool | None = None _body_length = 0 _send_headers_immediately = True def __init__( self, *, status: int = 200, reason: str | None = None, headers: 
LooseHeaders | None = None, _real_headers: CIMultiDict[str] | None = None, ) -> None: """Initialize a new stream response object. _real_headers is an internal parameter used to pass a pre-populated headers object. It is used by the `Response` class to avoid copying the headers when creating a new response object. It is not intended to be used by external code. """ self._state: dict[str | ResponseKey[Any], Any] = {} if _real_headers is not None: self._headers = _real_headers elif headers is not None: self._headers: CIMultiDict[str] = CIMultiDict(headers) else: self._headers = CIMultiDict() self._set_status(status, reason) @property def prepared(self) -> bool: return self._eof_sent or self._payload_writer is not None @property def task(self) -> "asyncio.Task[None] | None": if self._req: return self._req.task else: return None @property def status(self) -> int: return self._status @property def chunked(self) -> bool: return self._chunked @property def compression(self) -> bool: return self._compression @property def reason(self) -> str: return self._reason def set_status( self, status: int, reason: str | None = None, ) -> None: assert ( not self.prepared ), "Cannot change the response status code after the headers have been sent" self._set_status(status, reason) def _set_status(self, status: int, reason: str | None) -> None: self._status = status if reason is None: reason = REASON_PHRASES.get(self._status, "") elif "\r" in reason or "\n" in reason: raise ValueError("Reason cannot contain \\r or \\n") self._reason = reason @property def keep_alive(self) -> bool | None: return self._keep_alive def force_close(self) -> None: self._keep_alive = False @property def body_length(self) -> int: return self._body_length def enable_chunked_encoding(self) -> None: """Enables automatic chunked transfer encoding.""" if hdrs.CONTENT_LENGTH in self._headers: raise RuntimeError( "You can't enable chunked encoding when a content length is set" ) self._chunked = True def 
enable_compression( self, force: ContentCoding | None = None, strategy: int | None = None, ) -> None: """Enables response compression encoding.""" # Don't enable compression if content is already encoded. # This prevents double compression and provides a safe, predictable behavior # without breaking existing code that may call enable_compression() on # responses that already have Content-Encoding set (e.g., FileResponse # serving pre-compressed files). if hdrs.CONTENT_ENCODING in self._headers: return self._compression = True self._compression_force = force self._compression_strategy = strategy @property def headers(self) -> "CIMultiDict[str]": return self._headers @property def content_length(self) -> int | None: # Just a placeholder for adding setter return super().content_length @content_length.setter def content_length(self, value: int | None) -> None: if value is not None: value = int(value) if self._chunked: raise RuntimeError( "You can't set content length when chunked encoding is enable" ) self._headers[hdrs.CONTENT_LENGTH] = str(value) else: self._headers.pop(hdrs.CONTENT_LENGTH, None) @property def content_type(self) -> str: # Just a placeholder for adding setter return super().content_type @content_type.setter def content_type(self, value: str) -> None: self.content_type # read header values if needed self._content_type = str(value) self._generate_content_type_header() @property def charset(self) -> str | None: # Just a placeholder for adding setter return super().charset @charset.setter def charset(self, value: str | None) -> None: ctype = self.content_type # read header values if needed if ctype == "application/octet-stream": raise RuntimeError( "Setting charset for application/octet-stream " "doesn't make sense, setup content_type first" ) assert self._content_dict is not None if value is None: self._content_dict.pop("charset", None) else: self._content_dict["charset"] = str(value).lower() self._generate_content_type_header() @property def 
last_modified(self) -> datetime.datetime | None: """The value of Last-Modified HTTP header, or None. This header is represented as a `datetime` object. """ return parse_http_date(self._headers.get(hdrs.LAST_MODIFIED)) @last_modified.setter def last_modified( self, value: int | float | datetime.datetime | str | None ) -> None: if value is None: self._headers.pop(hdrs.LAST_MODIFIED, None) elif isinstance(value, (int, float)): self._headers[hdrs.LAST_MODIFIED] = time.strftime( "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)) ) elif isinstance(value, datetime.datetime): self._headers[hdrs.LAST_MODIFIED] = time.strftime( "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple() ) elif isinstance(value, str): self._headers[hdrs.LAST_MODIFIED] = value else: msg = f"Unsupported type for last_modified: {type(value).__name__}" # type: ignore[unreachable] raise TypeError(msg) @property def etag(self) -> ETag | None: quoted_value = self._headers.get(hdrs.ETAG) if not quoted_value: return None elif quoted_value == ETAG_ANY: return ETag(value=ETAG_ANY) match = QUOTED_ETAG_RE.fullmatch(quoted_value) if not match: return None is_weak, value = match.group(1, 2) return ETag( is_weak=bool(is_weak), value=value, ) @etag.setter def etag(self, value: ETag | str | None) -> None: if value is None: self._headers.pop(hdrs.ETAG, None) elif (isinstance(value, str) and value == ETAG_ANY) or ( isinstance(value, ETag) and value.value == ETAG_ANY ): self._headers[hdrs.ETAG] = ETAG_ANY elif isinstance(value, str): validate_etag_value(value) self._headers[hdrs.ETAG] = f'"{value}"' elif isinstance(value, ETag) and isinstance(value.value, str): # type: ignore[redundant-expr] validate_etag_value(value.value) hdr_value = f'W/"{value.value}"' if value.is_weak else f'"{value.value}"' self._headers[hdrs.ETAG] = hdr_value else: raise ValueError( f"Unsupported etag type: {type(value)}. 
" f"etag must be str, ETag or None" ) def _generate_content_type_header( self, CONTENT_TYPE: istr = hdrs.CONTENT_TYPE ) -> None: assert self._content_dict is not None assert self._content_type is not None params = "; ".join(f"{k}={v}" for k, v in self._content_dict.items()) if params: ctype = self._content_type + "; " + params else: ctype = self._content_type self._headers[CONTENT_TYPE] = ctype async def _do_start_compression(self, coding: ContentCoding) -> None: if coding is ContentCoding.identity: return assert self._payload_writer is not None self._headers[hdrs.CONTENT_ENCODING] = coding.value self._payload_writer.enable_compression( coding.value, self._compression_strategy ) # Compressed payload may have different content length, # remove the header self._headers.popall(hdrs.CONTENT_LENGTH, None) async def _start_compression(self, request: "BaseRequest") -> None: if self._compression_force: await self._do_start_compression(self._compression_force) return # Encoding comparisons should be case-insensitive # https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1 accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() for value, coding in CONTENT_CODINGS.items(): if value in accept_encoding: await self._do_start_compression(coding) return async def prepare(self, request: "BaseRequest") -> AbstractStreamWriter | None: if self._eof_sent: return None if self._payload_writer is not None: return self._payload_writer self._must_be_empty_body = must_be_empty_body(request.method, self.status) return await self._start(request) async def _start(self, request: "BaseRequest") -> AbstractStreamWriter: self._req = request writer = self._payload_writer = request._payload_writer await self._prepare_headers() await request._prepare_hook(self) await self._write_headers() return writer async def _prepare_headers(self) -> None: request = self._req assert request is not None writer = self._payload_writer assert writer is not None keep_alive = self._keep_alive if 
keep_alive is None: keep_alive = request.keep_alive self._keep_alive = keep_alive version = request.version headers = self._headers if self._cookies: populate_with_cookies(headers, self._cookies) if self._compression: await self._start_compression(request) if self._chunked: if version != HttpVersion11: raise RuntimeError( "Using chunked encoding is forbidden " f"for HTTP/{request.version.major}.{request.version.minor}" ) if not self._must_be_empty_body: writer.enable_chunking() headers[hdrs.TRANSFER_ENCODING] = "chunked" elif self._length_check: # Disabled for WebSockets writer.length = self.content_length if writer.length is None: if version >= HttpVersion11: if not self._must_be_empty_body: writer.enable_chunking() headers[hdrs.TRANSFER_ENCODING] = "chunked" elif not self._must_be_empty_body: keep_alive = False # HTTP 1.1: https://tools.ietf.org/html/rfc7230#section-3.3.2 # HTTP 1.0: https://tools.ietf.org/html/rfc1945#section-10.4 if self._must_be_empty_body: if hdrs.CONTENT_LENGTH in headers and should_remove_content_length( request.method, self.status ): del headers[hdrs.CONTENT_LENGTH] # https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-10 # https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-13 if hdrs.TRANSFER_ENCODING in headers: del headers[hdrs.TRANSFER_ENCODING] elif (writer.length if self._length_check else self.content_length) != 0: # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5 headers.setdefault(hdrs.CONTENT_TYPE, "application/octet-stream") headers.setdefault(hdrs.DATE, rfc822_formatted_time()) headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE) # connection header if hdrs.CONNECTION not in headers: if keep_alive: if version == HttpVersion10: headers[hdrs.CONNECTION] = "keep-alive" elif version == HttpVersion11: headers[hdrs.CONNECTION] = "close" async def _write_headers(self) -> None: request = self._req assert request is not None writer = self._payload_writer assert writer is not None # status line version = request.version 
status_line = f"HTTP/{version[0]}.{version[1]} {self._status} {self._reason}" await writer.write_headers(status_line, self._headers) # Send headers immediately if not opted into buffering if self._send_headers_immediately: writer.send_headers() async def write( self, data: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] ) -> None: assert isinstance( data, (bytes, bytearray, memoryview) ), "data argument must be byte-ish (%r)" % type(data) if self._eof_sent: raise RuntimeError("Cannot call write() after write_eof()") if self._payload_writer is None: raise RuntimeError("Cannot call write() before prepare()") await self._payload_writer.write(data) async def drain(self) -> None: assert not self._eof_sent, "EOF has already been sent" assert self._payload_writer is not None, "Response has not been started" warnings.warn( "drain method is deprecated, use await resp.write()", DeprecationWarning, stacklevel=2, ) await self._payload_writer.drain() async def write_eof(self, data: bytes = b"") -> None: assert isinstance( data, (bytes, bytearray, memoryview) ), "data argument must be byte-ish (%r)" % type(data) if self._eof_sent: return assert self._payload_writer is not None, "Response has not been started" await self._payload_writer.write_eof(data) self._eof_sent = True self._req = None self._body_length = self._payload_writer.output_size self._payload_writer = None def __repr__(self) -> str: if self._eof_sent: info = "eof" elif self.prepared: assert self._req is not None info = f"{self._req.method} {self._req.path} " else: info = "not prepared" return f"<{self.__class__.__name__} {self.reason} {info}>" @overload # type: ignore[override] def __getitem__(self, key: ResponseKey[_T]) -> _T: ... @overload def __getitem__(self, key: str) -> Any: ... def __getitem__(self, key: str | ResponseKey[_T]) -> Any: return self._state[key] @overload # type: ignore[override] def __setitem__(self, key: ResponseKey[_T], value: _T) -> None: ... 
@overload def __setitem__(self, key: str, value: Any) -> None: ... def __setitem__(self, key: str | ResponseKey[_T], value: Any) -> None: self._state[key] = value def __delitem__(self, key: str | ResponseKey[_T]) -> None: del self._state[key] def __len__(self) -> int: return len(self._state) def __iter__(self) -> Iterator[str | ResponseKey[Any]]: return iter(self._state) def __hash__(self) -> int: return hash(id(self)) def __eq__(self, other: object) -> bool: return self is other def __bool__(self) -> bool: return True class Response(StreamResponse): _compressed_body: bytes | None = None _send_headers_immediately = False def __init__( self, *, body: Any = None, status: int = 200, reason: str | None = None, text: str | None = None, headers: LooseHeaders | None = None, content_type: str | None = None, charset: str | None = None, zlib_executor_size: int | None = None, zlib_executor: Executor | None = None, ) -> None: if body is not None and text is not None: raise ValueError("body and text are not allowed together") if headers is None: real_headers: CIMultiDict[str] = CIMultiDict() else: real_headers = CIMultiDict(headers) if content_type is not None and "charset" in content_type: raise ValueError("charset must not be in content_type argument") if text is not None: if hdrs.CONTENT_TYPE in real_headers: if content_type or charset: raise ValueError( "passing both Content-Type header and " "content_type or charset params " "is forbidden" ) else: # fast path for filling headers if not isinstance(text, str): raise TypeError("text argument must be str (%r)" % type(text)) if content_type is None: content_type = "text/plain" if charset is None: charset = "utf-8" real_headers[hdrs.CONTENT_TYPE] = content_type + "; charset=" + charset body = text.encode(charset) text = None elif hdrs.CONTENT_TYPE in real_headers: if content_type is not None or charset is not None: raise ValueError( "passing both Content-Type header and " "content_type or charset params " "is forbidden" ) elif 
content_type is not None: if charset is not None: content_type += "; charset=" + charset real_headers[hdrs.CONTENT_TYPE] = content_type super().__init__(status=status, reason=reason, _real_headers=real_headers) if text is not None: self.text = text else: self.body = body self._zlib_executor_size = zlib_executor_size self._zlib_executor = zlib_executor @property def body(self) -> bytes | bytearray | Payload | None: return self._body @body.setter def body(self, body: Any) -> None: if body is None: self._body = None elif isinstance(body, (bytes, bytearray)): self._body = body else: try: self._body = body = payload.PAYLOAD_REGISTRY.get(body) except payload.LookupError: raise ValueError("Unsupported body type %r" % type(body)) headers = self._headers # set content-type if hdrs.CONTENT_TYPE not in headers: headers[hdrs.CONTENT_TYPE] = body.content_type # copy payload headers if body.headers: for key, value in body.headers.items(): if key not in headers: headers[key] = value self._compressed_body = None @property def text(self) -> str | None: if self._body is None: return None # Note: When _body is a Payload (e.g. 
FilePayload), this may do blocking I/O # This is generally safe as most common payloads (BytesPayload, StringPayload) # don't do blocking I/O, but be careful with file-based payloads return self._body.decode(self.charset or "utf-8") @text.setter def text(self, text: str) -> None: assert isinstance(text, str), "text argument must be str (%r)" % type(text) if self.content_type == "application/octet-stream": self.content_type = "text/plain" if self.charset is None: self.charset = "utf-8" self._body = text.encode(self.charset) self._compressed_body = None @property def content_length(self) -> int | None: if self._chunked: return None if hdrs.CONTENT_LENGTH in self._headers: return int(self._headers[hdrs.CONTENT_LENGTH]) if self._compressed_body is not None: # Return length of the compressed body return len(self._compressed_body) elif isinstance(self._body, Payload): # A payload without content length, or a compressed payload return None elif self._body is not None: return len(self._body) else: return 0 @content_length.setter def content_length(self, value: int | None) -> None: raise RuntimeError("Content length is set automatically") async def write_eof(self, data: bytes = b"") -> None: if self._eof_sent: return if self._compressed_body is None: body = self._body else: body = self._compressed_body assert not data, f"data arg is not supported, got {data!r}" assert self._req is not None assert self._payload_writer is not None if body is None or self._must_be_empty_body: await super().write_eof() elif isinstance(self._body, Payload): await self._body.write(self._payload_writer) await self._body.close() await super().write_eof() else: await super().write_eof(cast(bytes, body)) async def _start(self, request: "BaseRequest") -> AbstractStreamWriter: if hdrs.CONTENT_LENGTH in self._headers: if should_remove_content_length(request.method, self.status): del self._headers[hdrs.CONTENT_LENGTH] elif not self._chunked: if isinstance(self._body, Payload): if (size := 
self._body.size) is not None: self._headers[hdrs.CONTENT_LENGTH] = str(size) else: body_len = len(self._body) if self._body else "0" # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-7 if body_len != "0" or ( self.status != 304 and request.method not in hdrs.METH_HEAD_ALL ): self._headers[hdrs.CONTENT_LENGTH] = str(body_len) return await super()._start(request) async def _do_start_compression(self, coding: ContentCoding) -> None: if self._chunked or isinstance(self._body, Payload): return await super()._do_start_compression(coding) if coding is ContentCoding.identity: return # Instead of using _payload_writer.enable_compression, # compress the whole body compressor = ZLibCompressor( encoding=coding.value, max_sync_chunk_size=self._zlib_executor_size, executor=self._zlib_executor, ) assert self._body is not None if self._zlib_executor_size is None and len(self._body) > LARGE_BODY_SIZE: warnings.warn( "Synchronous compression of large response bodies " f"({len(self._body)} bytes) might block the async event loop. " "Consider providing a custom value to zlib_executor_size/" "zlib_executor response properties or disabling compression on it." 
) self._compressed_body = ( await compressor.compress(self._body) + compressor.flush() ) self._headers[hdrs.CONTENT_ENCODING] = coding.value self._headers[hdrs.CONTENT_LENGTH] = str(len(self._compressed_body)) def json_response( data: Any = sentinel, *, text: str | None = None, body: bytes | None = None, status: int = 200, reason: str | None = None, headers: LooseHeaders | None = None, content_type: str = "application/json", dumps: JSONEncoder = json.dumps, ) -> Response: if data is not sentinel: if text or body: raise ValueError("only one of data, text, or body should be specified") else: text = dumps(data) return Response( text=text, body=body, status=status, reason=reason, headers=headers, content_type=content_type, ) def json_bytes_response( data: Any = sentinel, *, dumps: JSONBytesEncoder, body: bytes | None = None, status: int = 200, reason: str | None = None, headers: LooseHeaders | None = None, content_type: str = "application/json", ) -> Response: """Create a JSON response using a bytes-returning encoder. Use this when your JSON encoder (like orjson) returns bytes instead of str, avoiding the encode/decode overhead. """ if data is not sentinel: if body is not None: raise ValueError("only one of data or body should be specified") else: body = dumps(data) return Response( body=body, status=status, reason=reason, headers=headers, content_type=content_type, ) ================================================ FILE: aiohttp/web_routedef.py ================================================ import abc import dataclasses from collections.abc import Callable, Iterator, Sequence from typing import TYPE_CHECKING, Any, Union, overload from . 
import hdrs  # completes the ``from .`` statement begun on the preceding physical line
from .abc import AbstractView
from .typedefs import Handler, PathLike

if TYPE_CHECKING:
    from .web_request import Request
    from .web_response import StreamResponse
    from .web_urldispatcher import AbstractRoute, UrlDispatcher
else:
    # Runtime placeholders: the annotations below reference these names,
    # but the real classes are only imported for type checking.
    Request = StreamResponse = UrlDispatcher = AbstractRoute = None


__all__ = (
    "AbstractRouteDef",
    "RouteDef",
    "StaticDef",
    "RouteTableDef",
    "head",
    "options",
    "get",
    "post",
    "patch",
    "put",
    "delete",
    "route",
    "view",
    "static",
)


class AbstractRouteDef(abc.ABC):
    """Base class for declarative route definitions."""

    @abc.abstractmethod
    def register(self, router: UrlDispatcher) -> list[AbstractRoute]:
        """Register itself into the given router."""


_HandlerType = Union[type[AbstractView], Handler]


@dataclasses.dataclass(frozen=True, repr=False)
class RouteDef(AbstractRouteDef):
    """Definition of a single route: HTTP method, path, handler and extra kwargs."""

    method: str
    path: str
    handler: _HandlerType
    kwargs: dict[str, Any]

    def __repr__(self) -> str:
        """Return ``<RouteDef METHOD PATH -> 'handler_name', key=value, ...>``.

        Extra keyword arguments are listed in sorted key order.
        """
        info = "".join(
            f", {name}={value!r}" for name, value in sorted(self.kwargs.items())
        )
        return (
            f"<RouteDef {self.method} {self.path} -> "
            f"{self.handler.__name__!r}{info}>"
        )

    def register(self, router: UrlDispatcher) -> list[AbstractRoute]:
        """Add this definition to *router* and return the result in a list.

        Methods listed in ``hdrs.METH_ALL`` are dispatched to the matching
        ``router.add_<method>()`` helper; any other method falls back to
        ``router.add_route()``.
        """
        if self.method in hdrs.METH_ALL:
            reg = getattr(router, "add_" + self.method.lower())
            return [reg(self.path, self.handler, **self.kwargs)]
        else:
            return [
                router.add_route(self.method, self.path, self.handler, **self.kwargs)
            ]


@dataclasses.dataclass(frozen=True, repr=False)
class StaticDef(AbstractRouteDef):
    """Definition of a static resource: URL prefix, filesystem path and extra kwargs."""

    prefix: str
    path: PathLike
    kwargs: dict[str, Any]

    def __repr__(self) -> str:
        """Return ``<StaticDef PREFIX -> PATH, key=value, ...>``.

        Extra keyword arguments are listed in sorted key order.
        """
        info = "".join(
            f", {name}={value!r}" for name, value in sorted(self.kwargs.items())
        )
        return f"<StaticDef {self.prefix} -> {self.path}{info}>"

    def register(self, router: UrlDispatcher) -> list[AbstractRoute]:
        """Call ``router.add_static()`` and return the routes it reports.

        The routes are taken from the ``"routes"`` entry of the new
        resource's ``get_info()`` mapping (empty list when absent).
        """
        resource = router.add_static(self.prefix, self.path, **self.kwargs)
        routes = resource.get_info().get("routes", {})
        return list(routes.values())


def route(method: str, path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
    """Build a :class:`RouteDef` for an arbitrary HTTP *method*."""
    return RouteDef(method, path, handler, kwargs)


def head(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
    """Build a :class:`RouteDef` for ``hdrs.METH_HEAD``."""
    return route(hdrs.METH_HEAD, path, handler, **kwargs)


def options(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
    """Build a :class:`RouteDef` for ``hdrs.METH_OPTIONS``."""
    return route(hdrs.METH_OPTIONS, path, handler, **kwargs)


def get(
    path: str,
    handler: _HandlerType,
    *,
    name: str | None = None,
    allow_head: bool = True,
    **kwargs: Any,
) -> RouteDef:
    """Build a :class:`RouteDef` for ``hdrs.METH_GET``.

    *name* and *allow_head* are always stored in the definition's kwargs
    (with their defaults when not given) and forwarded on registration.
    """
    return route(
        hdrs.METH_GET, path, handler, name=name, allow_head=allow_head, **kwargs
    )


def post(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
    """Build a :class:`RouteDef` for ``hdrs.METH_POST``."""
    return route(hdrs.METH_POST, path, handler, **kwargs)


def put(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
    """Build a :class:`RouteDef` for ``hdrs.METH_PUT``."""
    return route(hdrs.METH_PUT, path, handler, **kwargs)


def patch(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
    """Build a :class:`RouteDef` for ``hdrs.METH_PATCH``."""
    return route(hdrs.METH_PATCH, path, handler, **kwargs)


def delete(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
    """Build a :class:`RouteDef` for ``hdrs.METH_DELETE``."""
    return route(hdrs.METH_DELETE, path, handler, **kwargs)


def view(path: str, handler: type[AbstractView], **kwargs: Any) -> RouteDef:
    """Build a :class:`RouteDef` for a view class, using ``hdrs.METH_ANY``."""
    return route(hdrs.METH_ANY, path, handler, **kwargs)


def static(prefix: str, path: PathLike, **kwargs: Any) -> StaticDef:
    """Build a :class:`StaticDef` for *prefix* served from *path*."""
    return StaticDef(prefix, path, kwargs)


_Deco = Callable[[_HandlerType], _HandlerType]


class RouteTableDef(Sequence[AbstractRouteDef]):
    """Route definition table"""

    def __init__(self) -> None:
        self._items: list[AbstractRouteDef] = []

    def __repr__(self) -> str:
        """Return ``<RouteTableDef count=N>`` with the number of stored definitions."""
        return f"<RouteTableDef count={len(self._items)}>"

    @overload
    def __getitem__(self, index: int) -> AbstractRouteDef: ...

    @overload
    def __getitem__(self, index: "slice[int, int, int]") -> list[AbstractRouteDef]: ...

    # The concrete ``__getitem__`` and the remaining methods of this class
    # follow on the next physical line of the dump.
def __getitem__( self, index: Union[int, "slice[int, int, int]"] ) -> AbstractRouteDef | list[AbstractRouteDef]: return self._items[index] def __iter__(self) -> Iterator[AbstractRouteDef]: return iter(self._items) def __len__(self) -> int: return len(self._items) def __contains__(self, item: object) -> bool: return item in self._items def route(self, method: str, path: str, **kwargs: Any) -> _Deco: def inner(handler: _HandlerType) -> _HandlerType: self._items.append(RouteDef(method, path, handler, kwargs)) return handler return inner def head(self, path: str, **kwargs: Any) -> _Deco: return self.route(hdrs.METH_HEAD, path, **kwargs) def get(self, path: str, **kwargs: Any) -> _Deco: return self.route(hdrs.METH_GET, path, **kwargs) def post(self, path: str, **kwargs: Any) -> _Deco: return self.route(hdrs.METH_POST, path, **kwargs) def put(self, path: str, **kwargs: Any) -> _Deco: return self.route(hdrs.METH_PUT, path, **kwargs) def patch(self, path: str, **kwargs: Any) -> _Deco: return self.route(hdrs.METH_PATCH, path, **kwargs) def delete(self, path: str, **kwargs: Any) -> _Deco: return self.route(hdrs.METH_DELETE, path, **kwargs) def options(self, path: str, **kwargs: Any) -> _Deco: return self.route(hdrs.METH_OPTIONS, path, **kwargs) def view(self, path: str, **kwargs: Any) -> _Deco: return self.route(hdrs.METH_ANY, path, **kwargs) def static(self, prefix: str, path: PathLike, **kwargs: Any) -> None: self._items.append(StaticDef(prefix, path, kwargs)) ================================================ FILE: aiohttp/web_runner.py ================================================ import asyncio import signal import socket from abc import ABC, abstractmethod from typing import Any, Generic, TypeVar from yarl import URL from .abc import AbstractAccessLogger, AbstractStreamWriter from .http_parser import RawRequestMessage from .streams import StreamReader from .typedefs import PathLike from .web_app import Application from .web_log import AccessLogger from .web_protocol 
import RequestHandler from .web_request import BaseRequest, Request from .web_server import Server try: from ssl import SSLContext except ImportError: # pragma: no cover SSLContext = object # type: ignore[misc,assignment] __all__ = ( "BaseSite", "TCPSite", "UnixSite", "NamedPipeSite", "SockSite", "BaseRunner", "AppRunner", "ServerRunner", "GracefulExit", ) _Request = TypeVar("_Request", bound=BaseRequest) class GracefulExit(SystemExit): code = 1 def _raise_graceful_exit() -> None: raise GracefulExit() class BaseSite(ABC): __slots__ = ("_runner", "_ssl_context", "_backlog", "_server") def __init__( self, runner: "BaseRunner[Any]", *, ssl_context: SSLContext | None = None, backlog: int = 128, ) -> None: if runner.server is None: raise RuntimeError("Call runner.setup() before making a site") self._runner = runner self._ssl_context = ssl_context self._backlog = backlog self._server: asyncio.Server | None = None @property @abstractmethod def name(self) -> str: """Return the name of the site (e.g. a URL).""" @abstractmethod async def start(self) -> None: self._runner._reg_site(self) async def stop(self) -> None: self._runner._check_site(self) if self._server is not None: # Maybe not started yet self._server.close() self._runner._unreg_site(self) class TCPSite(BaseSite): __slots__ = ("_host", "_port", "_bound_port", "_reuse_address", "_reuse_port") def __init__( self, runner: "BaseRunner[Any]", host: str | None = None, port: int | None = None, *, ssl_context: SSLContext | None = None, backlog: int = 128, reuse_address: bool | None = None, reuse_port: bool | None = None, ) -> None: super().__init__( runner, ssl_context=ssl_context, backlog=backlog, ) self._host = host if port is None: port = 8443 if self._ssl_context else 8080 self._port = port self._bound_port: int | None = None self._reuse_address = reuse_address self._reuse_port = reuse_port @property def port(self) -> int: """The port the server is listening on. 
If the server hasn't been started yet, this returns the requested port (which might be 0 for a dynamic port). After the server starts, it returns the actual bound port. This is especially useful when port=0 was requested, as it allows retrieving the dynamically assigned port after the site has started. """ if self._bound_port is not None: return self._bound_port return self._port @property def name(self) -> str: scheme = "https" if self._ssl_context else "http" host = "0.0.0.0" if not self._host else self._host return str(URL.build(scheme=scheme, host=host, port=self.port)) async def start(self) -> None: await super().start() loop = asyncio.get_event_loop() server = self._runner.server assert server is not None self._server = await loop.create_server( server, self._host, self._port, ssl=self._ssl_context, backlog=self._backlog, reuse_address=self._reuse_address, reuse_port=self._reuse_port, ) if self._server.sockets: self._bound_port = self._server.sockets[0].getsockname()[1] else: self._bound_port = self._port class UnixSite(BaseSite): __slots__ = ("_path",) def __init__( self, runner: "BaseRunner[Any]", path: PathLike, *, ssl_context: SSLContext | None = None, backlog: int = 128, ) -> None: super().__init__( runner, ssl_context=ssl_context, backlog=backlog, ) self._path = path @property def name(self) -> str: scheme = "https" if self._ssl_context else "http" return f"{scheme}://unix:{self._path}:" async def start(self) -> None: await super().start() loop = asyncio.get_event_loop() server = self._runner.server assert server is not None self._server = await loop.create_unix_server( server, self._path, ssl=self._ssl_context, backlog=self._backlog, ) class NamedPipeSite(BaseSite): __slots__ = ("_path",) def __init__(self, runner: "BaseRunner[Any]", path: str) -> None: loop = asyncio.get_event_loop() if not isinstance( loop, asyncio.ProactorEventLoop # type: ignore[attr-defined] ): raise RuntimeError( "Named Pipes only available in proactor loop under windows" ) 
super().__init__(runner) self._path = path @property def name(self) -> str: return self._path async def start(self) -> None: await super().start() loop = asyncio.get_event_loop() server = self._runner.server assert server is not None _server = await loop.start_serving_pipe( # type: ignore[attr-defined] server, self._path ) self._server = _server[0] class SockSite(BaseSite): __slots__ = ("_sock", "_name") def __init__( self, runner: "BaseRunner[Any]", sock: socket.socket, *, ssl_context: SSLContext | None = None, backlog: int = 128, ) -> None: super().__init__( runner, ssl_context=ssl_context, backlog=backlog, ) self._sock = sock scheme = "https" if self._ssl_context else "http" if hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX: name = f"{scheme}://unix:{sock.getsockname()}:" else: host, port = sock.getsockname()[:2] name = str(URL.build(scheme=scheme, host=host, port=port)) self._name = name @property def name(self) -> str: return self._name async def start(self) -> None: await super().start() loop = asyncio.get_event_loop() server = self._runner.server assert server is not None self._server = await loop.create_server( server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog ) class BaseRunner(ABC, Generic[_Request]): __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout") def __init__( self, *, handle_signals: bool = False, shutdown_timeout: float = 60.0, **kwargs: Any, ) -> None: self._handle_signals = handle_signals self._kwargs = kwargs self._server: Server[_Request] | None = None self._sites: list[BaseSite] = [] self._shutdown_timeout = shutdown_timeout @property def server(self) -> Server[_Request] | None: return self._server @property def addresses(self) -> list[Any]: ret: list[Any] = [] for site in self._sites: server = site._server if server is not None: sockets = server.sockets if sockets is not None: for sock in sockets: ret.append(sock.getsockname()) return ret @property def sites(self) -> 
set[BaseSite]: return set(self._sites) async def setup(self) -> None: loop = asyncio.get_event_loop() if self._handle_signals: try: loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit) loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit) except NotImplementedError: # add_signal_handler is not implemented on Windows pass self._server = await self._make_server() @abstractmethod async def shutdown(self) -> None: """Call any shutdown hooks to help server close gracefully.""" async def cleanup(self) -> None: # The loop over sites is intentional, an exception on gather() # leaves self._sites in unpredictable state. # The loop guarantees that a site is either deleted on success or # still present on failure for site in list(self._sites): await site.stop() if self._server: # If setup succeeded # Yield to event loop to ensure incoming requests prior to stopping the sites # have all started to be handled before we proceed to close idle connections. await asyncio.sleep(0) self._server.pre_shutdown() await self.shutdown() await self._server.shutdown(self._shutdown_timeout) await self._cleanup_server() self._server = None if self._handle_signals: loop = asyncio.get_running_loop() try: loop.remove_signal_handler(signal.SIGINT) loop.remove_signal_handler(signal.SIGTERM) except NotImplementedError: # remove_signal_handler is not implemented on Windows pass @abstractmethod async def _make_server(self) -> Server[_Request]: """Return a new server for the runner to serve requests.""" @abstractmethod async def _cleanup_server(self) -> None: """Run any cleanup steps after the server is shutdown.""" def _reg_site(self, site: BaseSite) -> None: if site in self._sites: raise RuntimeError(f"Site {site} is already registered in runner {self}") self._sites.append(site) def _check_site(self, site: BaseSite) -> None: if site not in self._sites: raise RuntimeError(f"Site {site} is not registered in runner {self}") def _unreg_site(self, site: BaseSite) -> None: if site not in 
self._sites: raise RuntimeError(f"Site {site} is not registered in runner {self}") self._sites.remove(site) class ServerRunner(BaseRunner[BaseRequest]): """Low-level web server runner""" __slots__ = ("_web_server",) def __init__( self, web_server: Server[BaseRequest], *, handle_signals: bool = False, **kwargs: Any, ) -> None: super().__init__(handle_signals=handle_signals, **kwargs) self._web_server = web_server async def shutdown(self) -> None: pass async def _make_server(self) -> Server[BaseRequest]: return self._web_server async def _cleanup_server(self) -> None: pass class AppRunner(BaseRunner[Request]): """Web Application runner""" __slots__ = ("_app",) def __init__( self, app: Application, *, handle_signals: bool = False, access_log_class: type[AbstractAccessLogger] = AccessLogger, **kwargs: Any, ) -> None: if not isinstance(app, Application): raise TypeError( f"The first argument should be web.Application instance, got {app!r}" ) kwargs["access_log_class"] = access_log_class if app._handler_args: for k, v in app._handler_args.items(): kwargs[k] = v if not issubclass(kwargs["access_log_class"], AbstractAccessLogger): raise TypeError( "access_log_class must be subclass of " "aiohttp.abc.AbstractAccessLogger, got {}".format( kwargs["access_log_class"] ) ) super().__init__(handle_signals=handle_signals, **kwargs) self._app = app @property def app(self) -> Application: return self._app async def shutdown(self) -> None: await self._app.shutdown() async def _make_server(self) -> Server[Request]: self._app.on_startup.freeze() await self._app.startup() self._app.freeze() return Server( self._app._handle, request_factory=self._make_request, **self._kwargs, ) def _make_request( self, message: RawRequestMessage, payload: StreamReader, protocol: RequestHandler[Request], writer: AbstractStreamWriter, task: "asyncio.Task[None]", _cls: type[Request] = Request, ) -> Request: loop = asyncio.get_running_loop() return _cls( message, payload, protocol, writer, task, loop, 
class Server(Generic[_Request]):
    """Low-level HTTP server object.

    Calling an instance creates a new ``RequestHandler`` protocol; the
    instance also tracks the live connections so that they can be closed
    or shut down together.
    """

    request_factory: _RequestFactory[_Request]

    @overload
    def __init__(
        self: "Server[BaseRequest]",
        handler: Callable[[_Request], Awaitable[StreamResponse]],
        *,
        debug: bool | None = None,
        handler_cancellation: bool = False,
        **kwargs: Any,  # TODO(PY311): Use Unpack to define kwargs from RequestHandler
    ) -> None: ...

    @overload
    def __init__(
        self,
        handler: Callable[[_Request], Awaitable[StreamResponse]],
        *,
        request_factory: _RequestFactory[_Request] | None,
        debug: bool | None = None,
        handler_cancellation: bool = False,
        **kwargs: Any,
    ) -> None: ...

    def __init__(
        self,
        handler: Callable[[_Request], Awaitable[StreamResponse]],
        *,
        request_factory: _RequestFactory[_Request] | None = None,
        debug: bool | None = None,
        handler_cancellation: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store the request handler and the options used for each connection.

        Must be called while an event loop is running, because the running
        loop is captured here. Extra keyword arguments are kept and passed
        to every ``RequestHandler`` created by ``__call__``.
        """
        if debug is not None:
            warnings.warn(
                "debug argument is no-op since 4.0 and scheduled for removal in 5.0",
                DeprecationWarning,
                stacklevel=2,
            )
        self._loop = asyncio.get_running_loop()
        self._kwargs = kwargs
        self._connections: dict[RequestHandler[_Request], asyncio.Transport] = {}
        self.request_handler = handler
        self.handler_cancellation = handler_cancellation
        # requests_count is the number of requests being processed by the server
        # for the lifetime of the server.
        self.requests_count = 0
        # Fall back to the BaseRequest factory when none (or a falsy one) is given.
        self.request_factory = (  # type: ignore[assignment]
            request_factory if request_factory else self._make_request
        )

    @property
    def connections(self) -> list[RequestHandler[_Request]]:
        """Return a snapshot list of the currently tracked handlers."""
        return list(self._connections)

    def connection_made(
        self, handler: RequestHandler[_Request], transport: asyncio.Transport
    ) -> None:
        """Start tracking *handler* together with its transport."""
        self._connections[handler] = transport

    def connection_lost(
        self, handler: RequestHandler[_Request], exc: BaseException | None = None
    ) -> None:
        """Stop tracking *handler*.

        If the handler still has a task, removal is deferred until that
        task is done; otherwise the entry is removed immediately.
        """
        if handler not in self._connections:
            return
        pending_task = handler._task_handler
        if not pending_task:
            del self._connections[handler]
            return
        pending_task.add_done_callback(
            lambda _fut: self._connections.pop(handler, None)
        )

    def _make_request(
        self,
        message: RawRequestMessage,
        payload: StreamReader,
        protocol: RequestHandler[BaseRequest],
        writer: AbstractStreamWriter,
        task: "asyncio.Task[None]",
    ) -> BaseRequest:
        """Default request factory: build a plain ``BaseRequest``."""
        return BaseRequest(message, payload, protocol, writer, task, self._loop)

    def pre_shutdown(self) -> None:
        """Call ``close()`` on every tracked connection handler."""
        for connection in self._connections:
            connection.close()

    async def shutdown(self, timeout: float | None = None) -> None:
        """Shut down all tracked connections concurrently, then forget them."""
        shutdowns = [
            connection.shutdown(timeout) for connection in self._connections
        ]
        await asyncio.gather(*shutdowns)
        self._connections.clear()

    def __call__(self) -> RequestHandler[_Request]:
        """Create a ``RequestHandler`` for a new connection."""
        try:
            return RequestHandler(self, loop=self._loop, **self._kwargs)
        except TypeError:
            # Failsafe creation: remove all custom handler_args
            keep = ("debug", "access_log_class")
            fallback_kwargs = {
                key: self._kwargs[key] for key in self._kwargs if key in keep
            }
            return RequestHandler(self, loop=self._loop, **fallback_kwargs)
class _InfoDict(TypedDict, total=False):
    """Shape of the dict returned by the ``get_info()`` introspection methods.

    ``total=False`` makes every key optional: each resource, route or rule
    fills in only the keys that are relevant to it.
    """

    # Path of a plain resource (returned by PlainResource.get_info).
    path: str

    # Formatter string and compiled pattern of a dynamic resource
    # (returned together by DynamicResource.get_info).
    formatter: str
    pattern: Pattern[str]

    # Served directory, URL prefix and method-to-route mapping of a static
    # resource (StaticResource.get_info); ``prefix`` is also returned by
    # PrefixedSubAppResource.get_info.
    directory: Path
    prefix: str
    routes: Mapping[str, "AbstractRoute"]

    # Sub-application behind a sub-app resource.
    app: "Application"

    # Domain string of a Domain rule, and the rule object attached to a
    # MatchedSubAppResource.
    domain: str
    rule: "AbstractRuleMatching"

    # Exception carried by a SystemRoute.
    http_exception: HTTPException
""" @abc.abstractmethod def get_info(self) -> _InfoDict: """Return a dict with additional info useful for introspection""" def freeze(self) -> None: pass @abc.abstractmethod def raw_match(self, path: str) -> bool: """Perform a raw match against path""" class AbstractRoute(abc.ABC): def __init__( self, method: str, handler: Handler | type[AbstractView], *, expect_handler: _ExpectHandler | None = None, resource: AbstractResource | None = None, ) -> None: if expect_handler is None: expect_handler = _default_expect_handler assert inspect.iscoroutinefunction(expect_handler) or ( sys.version_info < (3, 14) and asyncio.iscoroutinefunction(expect_handler) ), f"Coroutine is expected, got {expect_handler!r}" method = method.upper() if not HTTP_METHOD_RE.match(method): raise ValueError(f"{method} is not allowed HTTP method") if inspect.iscoroutinefunction(handler) or ( sys.version_info < (3, 14) and asyncio.iscoroutinefunction(handler) ): pass elif isinstance(handler, type) and issubclass(handler, AbstractView): pass else: raise TypeError( f"Only async functions are allowed as web-handlers, got {handler!r}" ) self._method = method self._handler = handler self._expect_handler = expect_handler self._resource = resource @property def method(self) -> str: return self._method @property def handler(self) -> Handler: return self._handler @property @abc.abstractmethod def name(self) -> str | None: """Optional route's name, always equals to resource's name.""" @property def resource(self) -> AbstractResource | None: return self._resource @abc.abstractmethod def get_info(self) -> _InfoDict: """Return a dict with additional info useful for introspection""" @abc.abstractmethod # pragma: no branch def url_for(self, *args: str, **kwargs: str) -> URL: """Construct url for route with additional params.""" async def handle_expect_header(self, request: Request) -> StreamResponse | None: return await self._expect_handler(request) class UrlMappingMatchInfo(dict[str, str], AbstractMatchInfo): 
class MatchInfoError(UrlMappingMatchInfo):
    """Match info describing a request that could not be routed.

    It carries the HTTP exception to raise and is backed by a
    ``SystemRoute`` built from that exception, with an empty match dict.
    """

    __slots__ = ("_exception",)

    def __init__(self, http_exception: HTTPException) -> None:
        self._exception = http_exception
        super().__init__({}, SystemRoute(self._exception))

    @property
    def http_exception(self) -> HTTPException:
        """The exception describing why routing failed."""
        return self._exception

    def __repr__(self) -> str:
        # The repr was an empty string; show the status and reason, the same
        # two attributes SystemRoute exposes for the wrapped exception.
        return f"<MatchInfoError {self._exception.status}: {self._exception.reason}>"
class Resource(AbstractResource):
    """Base for resources that dispatch to routes by HTTP method.

    Subclasses provide ``_match()`` to decide whether a path belongs to
    the resource; this class stores the per-method routes and resolves
    requests against them.
    """

    def __init__(self, *, name: str | None = None) -> None:
        super().__init__(name=name)
        self._routes: dict[str, ResourceRoute] = {}
        self._any_route: ResourceRoute | None = None
        self._allowed_methods: set[str] = set()

    def add_route(
        self,
        method: str,
        handler: type[AbstractView] | Handler,
        *,
        expect_handler: _ExpectHandler | None = None,
    ) -> "ResourceRoute":
        """Create a route for *method* and register it on this resource.

        Raises RuntimeError if a route for the same method, or a
        catch-all route, is already registered.
        """
        # Routes are stored under their upper-cased method (the route
        # constructor normalises it), so the duplicate check must use the
        # same form; otherwise add_route("get", ...) after add_route("GET", ...)
        # would silently replace the existing route.
        if route := self._routes.get(method.upper(), self._any_route):
            raise RuntimeError(
                "Added route will never be executed, "
                f"method {route.method} is already "
                "registered"
            )

        route_obj = ResourceRoute(method, handler, self, expect_handler=expect_handler)
        self.register_route(route_obj)
        return route_obj

    def register_route(self, route: "ResourceRoute") -> None:
        """Store *route* under its method and record the method as allowed."""
        assert isinstance(
            route, ResourceRoute
        ), f"Instance of Route class is required, got {route!r}"
        if route.method == hdrs.METH_ANY:
            # Remembered separately so it can serve as the fallback route.
            self._any_route = route
        self._allowed_methods.add(route.method)
        self._routes[route.method] = route

    async def resolve(self, request: Request) -> _Resolve:
        """Return ``(match_info, allowed_methods)`` for *request*.

        ``match_info`` is None when the path does not match (with an empty
        method set) or when the path matches but no route accepts the
        request method (with the resource's allowed methods).
        """
        if (match_dict := self._match(request.rel_url.path_safe)) is None:
            return None, set()
        if route := self._routes.get(request.method, self._any_route):
            return UrlMappingMatchInfo(match_dict, route), self._allowed_methods
        return None, self._allowed_methods

    @abc.abstractmethod
    def _match(self, path: str) -> dict[str, str] | None:
        """Return dict of path values if path matches this resource, otherwise None."""

    def __len__(self) -> int:
        return len(self._routes)

    def __iter__(self) -> Iterator["ResourceRoute"]:
        return iter(self._routes.values())

    # TODO: implement all abstract methods
class DynamicResource(Resource):
    """Resource whose path contains ``{var}`` or ``{var:regex}`` parts.

    The path template is compiled into a regular expression used for
    matching and into a formatter string used by ``url_for()``.
    """

    # The group names must be ``var`` and ``re``: __init__ reads them with
    # match.group("var") and expands "{var}" / "{re}" from groupdict().
    # Without the <name> part, "(?P[" is not a valid regular expression.
    DYN = re.compile(r"\{(?P<var>[_a-zA-Z][_a-zA-Z0-9]*)\}")
    DYN_WITH_RE = re.compile(r"\{(?P<var>[_a-zA-Z][_a-zA-Z0-9]*):(?P<re>.+)\}")
    # Default pattern for a ``{var}`` part without an explicit regex.
    GOOD = r"[^{}/]+"

    def __init__(self, path: str, *, name: str | None = None) -> None:
        """Compile *path* into a matching pattern and a URL formatter.

        Raises ValueError for stray braces in a static part or when the
        assembled pattern is not a valid regular expression.
        """
        super().__init__(name=name)
        self._orig_path = path
        pattern = ""
        formatter = ""
        for part in ROUTE_RE.split(path):
            match = self.DYN.fullmatch(part)
            if match:
                pattern += "(?P<{}>{})".format(match.group("var"), self.GOOD)
                formatter += "{" + match.group("var") + "}"
                continue

            match = self.DYN_WITH_RE.fullmatch(part)
            if match:
                pattern += "(?P<{var}>{re})".format(**match.groupdict())
                formatter += "{" + match.group("var") + "}"
                continue

            if "{" in part or "}" in part:
                raise ValueError(f"Invalid path '{path}'['{part}']")

            # Static text: requote it, then escape it for use in the regex.
            part = _requote_path(part)
            formatter += part
            pattern += re.escape(part)

        try:
            compiled = re.compile(pattern)
        except re.error as exc:
            raise ValueError(f"Bad pattern '{pattern}': {exc}") from None
        assert compiled.pattern.startswith(PATH_SEP)
        assert formatter.startswith("/")
        self._pattern = compiled
        self._formatter = formatter

    @property
    def canonical(self) -> str:
        """The formatter string, e.g. ``/foo/{name}``."""
        return self._formatter

    def add_prefix(self, prefix: str) -> None:
        """Prepend *prefix* to both the pattern and the formatter."""
        assert prefix.startswith("/")
        assert not prefix.endswith("/")
        assert len(prefix) > 1
        self._pattern = re.compile(re.escape(prefix) + self._pattern.pattern)
        self._formatter = prefix + self._formatter

    def _match(self, path: str) -> dict[str, str] | None:
        """Return the unquoted variable values if *path* matches, else None."""
        match = self._pattern.fullmatch(path)
        if match is None:
            return None
        return {
            key: _unquote_path_safe(value) for key, value in match.groupdict().items()
        }

    def raw_match(self, path: str) -> bool:
        """Return True if *path* equals the original path template."""
        return self._orig_path == path

    def get_info(self) -> _InfoDict:
        return {"formatter": self._formatter, "pattern": self._pattern}

    def url_for(self, **parts: str) -> URL:
        """Build a URL by substituting quoted *parts* into the formatter."""
        url = self._formatter.format_map({k: _quote_path(v) for k, v in parts.items()})
        return URL.build(path=url, encoded=True)

    def __repr__(self) -> str:
        # ``name`` already ends with a space when set, so no double space
        # appears for unnamed resources. The repr previously returned an
        # empty string and left ``name`` unused.
        name = "'" + self.name + "' " if self.name is not None else ""
        return f"<DynamicResource {name} {self._formatter}>"
None: super().__init__(prefix, name=name) try: directory = Path(directory).expanduser().resolve(strict=True) except FileNotFoundError as error: raise ValueError(f"'{directory}' does not exist") from error if not directory.is_dir(): raise ValueError(f"'{directory}' is not a directory") self._directory = directory self._show_index = show_index self._chunk_size = chunk_size self._follow_symlinks = follow_symlinks self._expect_handler = expect_handler self._append_version = append_version self._routes = { "GET": ResourceRoute( "GET", self._handle, self, expect_handler=expect_handler ), "HEAD": ResourceRoute( "HEAD", self._handle, self, expect_handler=expect_handler ), } self._allowed_methods = set(self._routes) def url_for( # type: ignore[override] self, *, filename: PathLike, append_version: bool | None = None, ) -> URL: if append_version is None: append_version = self._append_version filename = str(filename).lstrip("/") url = URL.build(path=self._prefix, encoded=True) # filename is not encoded url = url / filename if append_version: unresolved_path = self._directory.joinpath(filename) try: if self._follow_symlinks: normalized_path = Path(os.path.normpath(unresolved_path)) normalized_path.relative_to(self._directory) filepath = normalized_path.resolve() else: filepath = unresolved_path.resolve() filepath.relative_to(self._directory) except (ValueError, FileNotFoundError): # ValueError for case when path point to symlink # with follow_symlinks is False return url # relatively safe if filepath.is_file(): # TODO cache file content # with file watcher for cache invalidation with filepath.open("rb") as f: file_bytes = f.read() h = self._get_file_hash(file_bytes) url = url.with_query({self.VERSION_KEY: h}) return url return url @staticmethod def _get_file_hash(byte_array: bytes) -> str: m = hashlib.sha256() # todo sha256 can be configurable param m.update(byte_array) b64 = base64.urlsafe_b64encode(m.digest()) return b64.decode("ascii") def get_info(self) -> _InfoDict: 
return { "directory": self._directory, "prefix": self._prefix, "routes": self._routes, } def set_options_route(self, handler: Handler) -> None: if "OPTIONS" in self._routes: raise RuntimeError("OPTIONS route was set already") self._routes["OPTIONS"] = ResourceRoute( "OPTIONS", handler, self, expect_handler=self._expect_handler ) self._allowed_methods.add("OPTIONS") async def resolve(self, request: Request) -> _Resolve: path = request.rel_url.path_safe method = request.method # We normalise here to avoid matches that traverse below the static root. # e.g. /static/../../../../home/user/webapp/static/ norm_path = os.path.normpath(path) if IS_WINDOWS: norm_path = norm_path.replace("\\", "/") if not norm_path.startswith(self._prefix2) and norm_path != self._prefix: return None, set() allowed_methods = self._allowed_methods if method not in allowed_methods: return None, allowed_methods match_dict = {"filename": _unquote_path_safe(path[len(self._prefix) + 1 :])} return (UrlMappingMatchInfo(match_dict, self._routes[method]), allowed_methods) def __len__(self) -> int: return len(self._routes) def __iter__(self) -> Iterator[AbstractRoute]: return iter(self._routes.values()) async def _handle(self, request: Request) -> StreamResponse: filename = request.match_info["filename"] if Path(filename).is_absolute(): # filename is an absolute path e.g. //network/share or D:\path # which could be a UNC path leading to NTLM credential theft raise HTTPNotFound() unresolved_path = self._directory.joinpath(filename) loop = asyncio.get_running_loop() return await loop.run_in_executor( None, self._resolve_path_to_response, unresolved_path ) def _resolve_path_to_response(self, unresolved_path: Path) -> StreamResponse: """Take the unresolved path and query the file system to form a response.""" # Check for access outside the root directory. For follow symlinks, URI # cannot traverse out, but symlinks can. Otherwise, no access outside # root is permitted. 
def _directory_as_html(self, dir_path: Path) -> str:
    """Return the directory's index as an HTML page.

    Entries are sorted, listed relative to the static root, and linked
    through the resource prefix. Directory entries get a trailing ``/``.
    Names are HTML-escaped and link targets are path-quoted.
    """
    assert dir_path.is_dir()

    relative_path_to_dir = dir_path.relative_to(self._directory).as_posix()
    index_of = f"Index of /{html_escape(relative_path_to_dir)}"
    h1 = f"<h1>{index_of}</h1>"

    index_list = []
    for _file in sorted(dir_path.iterdir()):
        # show file url as relative to static path
        rel_path = _file.relative_to(self._directory).as_posix()
        quoted_file_url = _quote_path(f"{self._prefix}/{rel_path}")

        # if file is a directory, add '/' to the end of the name
        file_name = f"{_file.name}/" if _file.is_dir() else _file.name

        # Without the anchor markup the computed URL was never used and
        # the listing contained no links.
        index_list.append(
            f'<li><a href="{quoted_file_url}">{html_escape(file_name)}</a></li>'
        )
    ul = "<ul>\n{}\n</ul>".format("\n".join(index_list))
    body = f"<body>\n{h1}\n{ul}\n</body>"

    head_str = f"<head>\n<title>{index_of}</title>\n</head>"
    # Named ``page`` so the local does not shadow the ``html`` module.
    page = f"<html>\n{head_str}\n{body}\n</html>"

    return page
class Domain(AbstractRuleMatching):
    """Rule that matches requests by the exact value of the Host header."""

    # One DNS label: 1-63 characters of letters, digits or '-', neither
    # starting nor ending with '-'. The trailing look-behind mirrors the
    # leading look-ahead; the pattern and the __init__ signature had been
    # fused into invalid syntax.
    re_part = re.compile(r"(?!-)[a-z\d-]{1,63}(?<!-)")

    def __init__(self, domain: str) -> None:
        super().__init__()
        self._domain = self.validation(domain)

    @property
    def canonical(self) -> str:
        """The validated, normalised domain string."""
        return self._domain

    def validation(self, domain: str) -> str:
        """Validate *domain* and return its normalised form.

        Trailing dots are removed and the value is lower-cased. The result
        is the host alone when the port is 80, otherwise ``host:port``.

        Raises TypeError if *domain* is not a str, and ValueError if it is
        empty, contains a scheme, or has an invalid label.
        """
        if not isinstance(domain, str):
            raise TypeError("Domain must be str")
        domain = domain.rstrip(".").lower()
        if not domain:
            raise ValueError("Domain cannot be empty")
        elif "://" in domain:
            raise ValueError("Scheme not supported")
        # Parse with a dummy scheme so host and port can be read separately.
        url = URL("http://" + domain)
        assert url.raw_host is not None
        if not all(self.re_part.fullmatch(x) for x in url.raw_host.split(".")):
            raise ValueError("Domain not valid")
        if url.port == 80:
            return url.raw_host
        return f"{url.raw_host}:{url.port}"

    async def match(self, request: Request) -> bool:
        """Return True if the request's Host header matches this rule."""
        host = request.headers.get(hdrs.HOST)
        if not host:
            return False
        return self.match_domain(host)

    def match_domain(self, host: str) -> bool:
        """Compare the lower-cased *host* with the stored domain."""
        return host.lower() == self._domain

    def get_info(self) -> _InfoDict:
        return {"domain": self._domain}
class MatchedSubAppResource(PrefixedSubAppResource):
    """Sub-application resource selected by a rule instead of a URL prefix."""

    def __init__(self, rule: AbstractRuleMatching, app: "Application") -> None:
        # Only the AbstractResource initialiser runs here; the prefix-based
        # parents are not initialised, so no prefix is applied to the
        # sub-application's resources and the prefix stays empty.
        AbstractResource.__init__(self)
        self._prefix = ""
        self._app = app
        self._rule = rule

    @property
    def canonical(self) -> str:
        """Canonical string of the matching rule."""
        return self._rule.canonical

    def get_info(self) -> _InfoDict:
        return {"app": self._app, "rule": self._rule}

    async def resolve(self, request: Request) -> _Resolve:
        """Resolve *request* in the sub-application if the rule matches.

        Returns ``(None, set())`` when the rule rejects the request.
        """
        if not await self._rule.match(request):
            return None, set()
        # The remaining steps were a copy of the parent's resolve()
        # (delegate to the sub-app router, register the app, collect the
        # allowed methods), so reuse it.
        return await super().resolve(request)

    def __repr__(self) -> str:
        # The repr had lost its opening "<ClassName" part, leaving an
        # unbalanced closing bracket.
        return f"<MatchedSubAppResource -> {self._app!r}>"
class View(AbstractView):
    """Class-based view.

    Awaiting an instance calls the method named after the lower-cased
    HTTP method of the request, or raises ``HTTPMethodNotAllowed``.
    """

    async def _iter(self) -> StreamResponse:
        """Find the method implementing the request's HTTP method and await it."""
        http_method = self.request.method
        if http_method not in hdrs.METH_ALL:
            self._raise_allowed_methods()

        implementation: Callable[[], Awaitable[StreamResponse]] | None = getattr(
            self, http_method.lower(), None
        )
        if implementation is None:
            self._raise_allowed_methods()
        return await implementation()

    def __await__(self) -> Generator[None, None, StreamResponse]:
        return self._iter().__await__()

    def _raise_allowed_methods(self) -> NoReturn:
        """Raise ``HTTPMethodNotAllowed`` listing the methods this view defines."""
        supported = {
            known for known in hdrs.METH_ALL if hasattr(self, known.lower())
        }
        raise HTTPMethodNotAllowed(self.request.method, supported)
async def resolve(self, request: Request) -> UrlMappingMatchInfo:
    """Find the match info for *request*.

    Rule-matched sub-applications are consulted first, then the indexed
    resources. When nothing matches, the result wraps
    ``HTTPMethodNotAllowed`` if any candidate reported allowed methods,
    and the shared ``HTTP_NOT_FOUND`` instance otherwise.
    """
    index = self._resource_index
    seen_methods: set[str] = set()

    # MatchedSubAppResource is primarily used to match on domain names
    # (though custom rules could match on other things). This means that
    # the traversal algorithm below can't be applied, and that we likely
    # need to check these first so a sub app that defines the same path
    # as a parent app will get priority if there's a domain match.
    #
    # For most cases we do not expect there to be many of these since
    # currently they are only added by `.add_domain()`.
    for sub_app_resource in self._matched_sub_app_resources:
        found, methods = await sub_app_resource.resolve(request)
        if found is not None:
            return found
        seen_methods.update(methods)

    # Walk the url parts looking for candidates. We walk the url backwards
    # to ensure the most explicit match is found first. If there are multiple
    # candidates for a given url part because there are multiple resources
    # registered for the same canonical path, we resolve them in a linear
    # fashion to ensure registration order is respected.
    remaining = request.rel_url.path_safe
    while remaining:
        for candidate in index.get(remaining, ()):
            found, methods = await candidate.resolve(request)
            if found is not None:
                return found
            seen_methods.update(methods)
        if remaining == "/":
            break
        # Drop the last segment; an empty result means we reached the root.
        parent, _, _ = remaining.rpartition("/")
        remaining = parent or "/"

    if seen_methods:
        return MatchInfoError(HTTPMethodNotAllowed(request.method, seen_methods))

    return MatchInfoError(self.HTTP_NOT_FOUND)
# --- UrlDispatcher methods (class interior, shown flattened) ---------------


def _get_resource_index_key(self, resource: AbstractResource) -> str:
    """Return the key under which *resource* is stored in the resource index.

    The index is keyed on URL parts split by ``/``.  For a path containing
    variables the key is cut at the first ``{`` and then at the last ``/``
    before it, so ``/core/locations{tail:.*}`` is indexed under ``/core``.
    An empty key is normalised to ``/``.
    """
    canonical = resource.canonical
    if "{" in canonical:
        static_head, _, _ = canonical.partition("{")
        canonical, _, _ = static_head.rpartition("/")
    return canonical.rstrip("/") or "/"


def index_resource(self, resource: AbstractResource) -> None:
    """Add a resource to the resource index."""
    key = self._get_resource_index_key(resource)
    # Several resources may share one key; keeping a list per key
    # preserves registration order.
    bucket = self._resource_index.setdefault(key, [])
    bucket.append(resource)


def unindex_resource(self, resource: AbstractResource) -> None:
    """Remove a resource from the resource index."""
    key = self._get_resource_index_key(resource)
    bucket = self._resource_index[key]
    bucket.remove(resource)


def add_resource(self, path: str, *, name: str | None = None) -> Resource:
    """Return a resource for *path*, registering a new one when needed.

    The most recently added resource is reused when it has the same name
    and raw-matches *path*.  Raises ValueError for a non-empty path that
    does not start with ``/``.
    """
    if path and not path.startswith("/"):
        raise ValueError("path should be started with / or be empty")
    # Reuse last added resource if path and name are the same
    if self._resources:
        last = self._resources[-1]
        if last.name == name and last.raw_match(path):
            return cast(Resource, last)
    resource: Resource
    if "{" in path or "}" in path or ROUTE_RE.search(path):
        resource = DynamicResource(path, name=name)
    else:
        resource = PlainResource(path, name=name)
    self.register_resource(resource)
    return resource


def add_route(
    self,
    method: str,
    path: str,
    handler: Handler | type[AbstractView],
    *,
    name: str | None = None,
    expect_handler: _ExpectHandler | None = None,
) -> AbstractRoute:
    """Add *handler* for *method* on the resource for *path*."""
    target = self.add_resource(path, name=name)
    return target.add_route(method, handler, expect_handler=expect_handler)


def add_static(
    self,
    prefix: str,
    path: PathLike,
    *,
    name: str | None = None,
    expect_handler: _ExpectHandler | None = None,
    chunk_size: int = 256 * 1024,
    show_index: bool = False,
    follow_symlinks: bool = False,
    append_version: bool = False,
) -> StaticResource:
    """Add static files view.

    prefix - url prefix
    path - folder with files
    """
    assert prefix.startswith("/")
    # Drop a single trailing slash from the prefix.
    prefix = prefix.removesuffix("/")
    static = StaticResource(
        prefix,
        path,
        name=name,
        expect_handler=expect_handler,
        chunk_size=chunk_size,
        show_index=show_index,
        follow_symlinks=follow_symlinks,
        append_version=append_version,
    )
    self.register_resource(static)
    return static


def add_head(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute:
    """Shortcut for add_route with method HEAD."""
    return self.add_route(hdrs.METH_HEAD, path, handler, **kwargs)


def add_options(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute:
    """Shortcut for add_route with method OPTIONS."""
    return self.add_route(hdrs.METH_OPTIONS, path, handler, **kwargs)


def add_get(
    self,
    path: str,
    handler: Handler,
    *,
    name: str | None = None,
    allow_head: bool = True,
    **kwargs: Any,
) -> AbstractRoute:
    """Shortcut for add_route with method GET.

    If allow_head is true, another route is added allowing head requests
    to the same endpoint.
    """
    target = self.add_resource(path, name=name)
    if allow_head:
        # The HEAD route is registered first; only the GET route is returned.
        target.add_route(hdrs.METH_HEAD, handler, **kwargs)
    get_route = target.add_route(hdrs.METH_GET, handler, **kwargs)
    return get_route


def add_post(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute:
    """Shortcut for add_route with method POST."""
    return self.add_route(hdrs.METH_POST, path, handler, **kwargs)


def add_put(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute:
    """Shortcut for add_route with method PUT."""
    return self.add_route(hdrs.METH_PUT, path, handler, **kwargs)


def add_patch(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute:
    """Shortcut for add_route with method PATCH."""
    return self.add_route(hdrs.METH_PATCH, path, handler, **kwargs)


def add_delete(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute:
    """Shortcut for add_route with method DELETE."""
    return self.add_route(hdrs.METH_DELETE, path, handler, **kwargs)


def add_view(
    self, path: str, handler: type[AbstractView], **kwargs: Any
) -> AbstractRoute:
    """Shortcut for add_route with ANY methods for a class-based view."""
    return self.add_route(hdrs.METH_ANY, path, handler, **kwargs)


def freeze(self) -> None:
    """Freeze the dispatcher and every registered resource."""
    super().freeze()
    for registered in self._resources:
        registered.freeze()


def add_routes(self, routes: Iterable[AbstractRouteDef]) -> list[AbstractRoute]:
    """Append routes to route table.

    Parameter should be a sequence of RouteDef objects.

    Returns a list of registered AbstractRoute instances.
    """
    return [route for route_def in routes for route in route_def.register(self)]


# --- module-level path helpers ----------------------------------------------


def _quote_path(value: str) -> str:
    """Return *value* percent-encoded as a URL path."""
    return URL.build(path=value, encoded=False).raw_path


def _unquote_path_safe(value: str) -> str:
    """Decode only ``%2F`` and ``%25`` in *value*, in that order."""
    if "%" in value:
        for quoted, plain in (("%2F", "/"), ("%25", "%")):
            value = value.replace(quoted, plain)
    return value


def _requote_path(value: str) -> str:
    """Quote *value* while keeping %-sequences it already contains."""
    # Quote non-ascii characters and other characters which must be quoted,
    # but preserve existing %-sequences.
    quoted = _quote_path(value)
    return quoted.replace("%25", "%") if "%" in value else quoted
# --- WebSocketResponse methods (class interior, shown flattened) ------------


def _cancel_heartbeat(self) -> None:
    """Cancel every pending heartbeat timer, reset handle and ping task."""
    self._cancel_pong_response_cb()
    if (reset_handle := self._heartbeat_reset_handle) is not None:
        reset_handle.cancel()
        self._heartbeat_reset_handle = None
    self._need_heartbeat_reset = False
    if (heartbeat_cb := self._heartbeat_cb) is not None:
        heartbeat_cb.cancel()
        self._heartbeat_cb = None
    if (ping_task := self._ping_task) is not None:
        ping_task.cancel()
        self._ping_task = None


def _cancel_pong_response_cb(self) -> None:
    """Cancel the timer that fires when no PONG arrives, if one is set."""
    pong_cb = self._pong_response_cb
    if pong_cb is not None:
        pong_cb.cancel()
        self._pong_response_cb = None


def _on_data_received(self) -> None:
    """Schedule one heartbeat reset for the current loop iteration."""
    if self._heartbeat is None or self._need_heartbeat_reset:
        return
    loop = self._loop
    assert loop is not None
    # Coalesce multiple chunks received in the same loop tick into a single
    # heartbeat reset. Resetting immediately per chunk increases timer churn.
    self._need_heartbeat_reset = True
    self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset)


def _flush_heartbeat_reset(self) -> None:
    """Run the heartbeat reset scheduled by ``_on_data_received``."""
    self._heartbeat_reset_handle = None
    if self._need_heartbeat_reset:
        self._reset_heartbeat()
        self._need_heartbeat_reset = False


def _reset_heartbeat(self) -> None:
    """Move the heartbeat deadline forward and make sure a timer exists."""
    heartbeat = self._heartbeat
    if heartbeat is None:
        return
    self._cancel_pong_response_cb()
    req = self._req
    ceil_threshold = 5 if req is None else req._protocol._timeout_ceil_threshold
    loop = self._loop
    assert loop is not None
    when = calculate_timeout_when(loop.time(), heartbeat, ceil_threshold)
    self._heartbeat_when = when
    if self._heartbeat_cb is None:
        # We do not cancel the previous heartbeat_cb here because
        # it generates a significant amount of TimerHandle churn
        # which causes asyncio to rebuild the heap frequently.
        # Instead _send_heartbeat() will reschedule the next
        # heartbeat if it fires too early.
        self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)


def _send_heartbeat(self) -> None:
    """Timer callback: send a PING and arm the PONG timeout."""
    self._heartbeat_cb = None
    # If heartbeat reset is pending (data is being received), skip sending
    # the ping and let the reset callback handle rescheduling the heartbeat.
    if self._need_heartbeat_reset:
        return
    loop = self._loop
    assert loop is not None and self._writer is not None
    now = loop.time()
    if now < self._heartbeat_when:
        # Heartbeat fired too early, reschedule
        self._heartbeat_cb = loop.call_at(self._heartbeat_when, self._send_heartbeat)
        return

    req = self._req
    ceil_threshold = 5 if req is None else req._protocol._timeout_ceil_threshold
    pong_deadline = calculate_timeout_when(now, self._pong_heartbeat, ceil_threshold)
    self._cancel_pong_response_cb()
    self._pong_response_cb = loop.call_at(pong_deadline, self._pong_not_received)

    coro = self._writer.send_frame(b"", WSMsgType.PING)
    if sys.version_info >= (3, 12):
        # Optimization for Python 3.12, try to send the ping
        # immediately to avoid having to schedule
        # the task on the event loop.
        ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
    else:
        ping_task = loop.create_task(coro)

    if ping_task.done():
        self._ping_task_done(ping_task)
    else:
        self._ping_task = ping_task
        ping_task.add_done_callback(self._ping_task_done)


def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
    """Callback for when the ping task completes."""
    if not task.cancelled() and (exc := task.exception()):
        self._handle_ping_pong_exception(exc)
    self._ping_task = None


def _pong_not_received(self) -> None:
    """Timer callback: report a missing PONG while a transport is present."""
    if self._req is not None and self._req.transport is not None:
        self._handle_ping_pong_exception(
            asyncio.TimeoutError(
                f"No PONG received after {self._pong_heartbeat} seconds"
            )
        )


def _handle_ping_pong_exception(self, exc: BaseException) -> None:
    """Handle exceptions raised during ping/pong processing."""
    if self._closed:
        return
    self._set_closed()
    self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
    self._exception = exc
    # Wake a pending receive() with an error message.
    if self._waiting and not self._closing and self._reader is not None:
        self._reader.feed_data(WSMessageError(data=exc, extra=None))


def _set_closed(self) -> None:
    """Set the connection to closed.

    Cancel any heartbeat timers and set the closed flag.
    """
    self._closed = True
    self._cancel_heartbeat()


async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
    """Perform the handshake and return the payload writer.

    Returns the existing payload writer when the response was already
    prepared.
    """
    # make pre-check to don't hide it by do_handshake() exceptions
    if self._payload_writer is not None:
        return self._payload_writer

    protocol, writer = self._pre_start(request)
    payload_writer = await super().prepare(request)
    assert payload_writer is not None
    self._post_start(request, protocol, writer)
    await payload_writer.drain()
    return payload_writer


@staticmethod
def _is_valid_ws_key(key: str | None) -> bool:
    """Return True if *key* is base64 text decoding to exactly 16 bytes.

    ``base64.b64decode`` raises ``binascii.Error`` for malformed input and a
    plain ``ValueError`` for a ``str`` with non-ASCII characters; both mean
    the key is invalid.
    """
    if not key:
        return False
    try:
        return len(base64.b64decode(key)) == 16
    except (binascii.Error, ValueError):
        return False


def _handshake(
    self, request: BaseRequest
) -> tuple["CIMultiDict[str]", str | None, int, bool]:
    """Validate the upgrade request and build the handshake response.

    Returns ``(response_headers, protocol, compress, notakeover)``.

    Raises HTTPBadRequest when the Upgrade or Connection header, the
    WebSocket version, or the Sec-WebSocket-Key is not acceptable.
    """
    headers = request.headers
    if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip():
        raise HTTPBadRequest(
            text=(
                f"No WebSocket UPGRADE hdr: {headers.get(hdrs.UPGRADE)}\n Can "
                '"Upgrade" only to "WebSocket".'
            )
        )

    if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower():
        raise HTTPBadRequest(
            text=f"No CONNECTION upgrade hdr: {headers.get(hdrs.CONNECTION)}"
        )

    # find common sub-protocol between client and server
    protocol: str | None = None
    if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
        req_protocols = [
            str(proto.strip())
            for proto in headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
        ]

        for proto in req_protocols:
            if proto in self._protocols:
                protocol = proto
                break
        else:
            # No overlap found: Return no protocol as per spec
            ws_logger.warning(
                "%s: Client protocols %r don’t overlap server-known ones %r",
                request.remote,
                req_protocols,
                self._protocols,
            )

    # check supported version
    version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, "")
    if version not in ("13", "8", "7"):
        raise HTTPBadRequest(text=f"Unsupported version: {version}")

    # check client handshake for validity; a missing, malformed, wrong-sized
    # or non-ASCII key is always answered with 400
    key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
    if not self._is_valid_ws_key(key):
        raise HTTPBadRequest(text=f"Handshake error: {key!r}")
    assert key is not None  # narrowed for type checkers

    accept_val = base64.b64encode(
        hashlib.sha1(key.encode() + WS_KEY).digest()
    ).decode()
    response_headers = CIMultiDict(
        {
            hdrs.UPGRADE: "websocket",
            hdrs.CONNECTION: "upgrade",
            hdrs.SEC_WEBSOCKET_ACCEPT: accept_val,
        }
    )

    notakeover = False
    compress = 0
    if self._compress:
        extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
        # Server side always get return with no exception.
        # If something happened, just drop compress extension
        compress, notakeover = ws_ext_parse(extensions, isserver=True)
        if compress:
            enabledext = ws_ext_gen(
                compress=compress, isserver=True, server_notakeover=notakeover
            )
            response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext

    if protocol:
        response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
    return (
        response_headers,
        protocol,
        compress,
        notakeover,
    )


def _pre_start(self, request: BaseRequest) -> tuple[str | None, WebSocketWriter]:
    """Run the handshake, set the 101 response headers and build the writer."""
    self._loop = request._loop

    headers, protocol, compress, notakeover = self._handshake(request)

    self.set_status(101)
    self.headers.update(headers)
    self.force_close()
    self._compress = compress
    transport = request._protocol.transport
    assert transport is not None
    writer = WebSocketWriter(
        request._protocol,
        transport,
        compress=compress,
        notakeover=notakeover,
        limit=self._writer_limit,
    )

    return protocol, writer


def _post_start(
    self, request: BaseRequest, protocol: str | None, writer: WebSocketWriter
) -> None:
    """Store the negotiated state and install the WebSocket parser."""
    self._ws_protocol = protocol
    self._writer = writer

    self._reset_heartbeat()

    loop = self._loop
    assert loop is not None
    self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop)
    parser = WebSocketReader(
        self._reader,
        self._max_msg_size,
        compress=bool(self._compress),
        decode_text=self._decode_text,
    )
    # The data-received hook is only installed when a heartbeat is set.
    cb = None if self._heartbeat is None else self._on_data_received
    request.protocol.set_parser(parser, data_received_cb=cb)
    # disable HTTP keepalive for WebSocket
    request.protocol.keep_alive(False)


def can_prepare(self, request: BaseRequest) -> WebSocketReady:
    """Report whether *request* would pass the WebSocket handshake.

    Raises RuntimeError if the response has already been started.
    """
    if self._writer is not None:
        raise RuntimeError("Already started")
    try:
        _, protocol, _, _ = self._handshake(request)
    except HTTPException:
        return WebSocketReady(False, None)
    else:
        return WebSocketReady(True, protocol)
# --- WebSocketResponse accessors and frame senders (class interior) ---------


@property
def prepared(self) -> bool:
    """True once a WebSocket writer has been set."""
    return self._writer is not None


@property
def closed(self) -> bool:
    """Value of the closed flag."""
    return self._closed


@property
def close_code(self) -> int | None:
    """Close code recorded for the connection, if any."""
    return self._close_code


@property
def ws_protocol(self) -> str | None:
    """Sub-protocol stored for this connection, if any."""
    return self._ws_protocol


@property
def compress(self) -> int | bool:
    """Compression setting stored for this connection."""
    return self._compress


def get_extra_info(self, name: str, default: Any = None) -> Any:
    """Get optional transport information.

    If no value associated with ``name`` is found, ``default`` is returned.
    """
    ws_writer = self._writer
    if ws_writer is not None:
        return ws_writer.transport.get_extra_info(name, default)
    return default


def exception(self) -> BaseException | None:
    """Return the stored exception, if any."""
    return self._exception


async def ping(self, message: bytes = b"") -> None:
    """Send a PING frame carrying *message*."""
    ws_writer = self._writer
    if ws_writer is None:
        raise RuntimeError("Call .prepare() first")
    await ws_writer.send_frame(message, WSMsgType.PING)


async def pong(self, message: bytes = b"") -> None:
    """Send a PONG frame carrying *message* (may be unsolicited)."""
    ws_writer = self._writer
    if ws_writer is None:
        raise RuntimeError("Call .prepare() first")
    await ws_writer.send_frame(message, WSMsgType.PONG)


async def send_frame(
    self, message: bytes, opcode: WSMsgType, compress: int | None = None
) -> None:
    """Send a frame over the websocket."""
    ws_writer = self._writer
    if ws_writer is None:
        raise RuntimeError("Call .prepare() first")
    await ws_writer.send_frame(message, opcode, compress)


async def send_str(self, data: str, compress: int | None = None) -> None:
    """Send *data* as a TEXT frame, UTF-8 encoded.

    Raises RuntimeError before ``prepare()`` and TypeError for non-str data.
    """
    ws_writer = self._writer
    if ws_writer is None:
        raise RuntimeError("Call .prepare() first")
    if not isinstance(data, str):
        raise TypeError("data argument must be str (%r)" % type(data))
    payload = data.encode("utf-8")
    await ws_writer.send_frame(payload, WSMsgType.TEXT, compress=compress)


async def send_bytes(self, data: bytes, compress: int | None = None) -> None:
    """Send *data* as a BINARY frame.

    Raises RuntimeError before ``prepare()`` and TypeError unless *data* is
    bytes, bytearray or memoryview.
    """
    ws_writer = self._writer
    if ws_writer is None:
        raise RuntimeError("Call .prepare() first")
    if not isinstance(data, (bytes, bytearray, memoryview)):
        raise TypeError("data argument must be byte-ish (%r)" % type(data))
    await ws_writer.send_frame(data, WSMsgType.BINARY, compress=compress)
# --- WebSocketResponse close/receive API (class interior, shown flattened) --


async def send_json(
    self,
    data: Any,
    compress: int | None = None,
    *,
    dumps: JSONEncoder = json.dumps,
) -> None:
    """Serialise *data* with *dumps* and send it through ``send_str``."""
    text = dumps(data)
    await self.send_str(text, compress=compress)


async def send_json_bytes(
    self,
    data: Any,
    compress: int | None = None,
    *,
    dumps: JSONBytesEncoder,
) -> None:
    """Send JSON data using a bytes-returning encoder as a binary frame.

    Use this when your JSON encoder (like orjson) returns bytes instead
    of str, avoiding the encode/decode overhead.
    """
    payload = dumps(data)
    await self.send_bytes(payload, compress=compress)


async def write_eof(self) -> None:  # type: ignore[override]
    """Close the websocket once; later calls return immediately."""
    if self._eof_sent:
        return
    if self._payload_writer is None:
        raise RuntimeError("Response has not been started")

    await self.close()
    self._eof_sent = True


async def close(
    self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
) -> bool:
    """Close websocket connection.

    Returns False if the connection was already closed, True otherwise.
    Raises RuntimeError before ``prepare()``; cancellation and timeout
    errors are re-raised after the transport has been closed.
    """
    ws_writer = self._writer
    if ws_writer is None:
        raise RuntimeError("Call .prepare() first")

    if self._closed:
        return False
    self._set_closed()

    try:
        await ws_writer.close(code, message)
        payload_writer = self._payload_writer
        assert payload_writer is not None
        if drain:
            await payload_writer.drain()
    except (asyncio.CancelledError, asyncio.TimeoutError):
        self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
        raise
    except Exception as exc:
        self._exception = exc
        self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
        return True

    reader = self._reader
    assert reader is not None
    # we need to break `receive()` cycle before we can call
    # `reader.read()` as `close()` may be called from different task
    if self._waiting:
        assert self._loop is not None
        assert self._close_wait is None
        self._close_wait = self._loop.create_future()
        reader.feed_data(WS_CLOSING_MESSAGE)
        await self._close_wait

    if self._closing:
        self._close_transport()
        return True

    # Read until a CLOSE frame arrives, bounded by the close timeout.
    try:
        async with async_timeout.timeout(self._timeout):
            while True:
                incoming = await reader.read()
                if incoming.type is WSMsgType.CLOSE:
                    self._set_code_close_transport(incoming.data)
                    return True
    except asyncio.CancelledError:
        self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
        raise
    except Exception as exc:
        self._exception = exc
        self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
        return True


def _set_closing(self, code: int) -> None:
    """Set the close code and mark the connection as closing."""
    self._closing = True
    self._close_code = code
    self._cancel_heartbeat()


def _set_code_close_transport(self, code: int) -> None:
    """Set the close code and close the transport."""
    self._close_code = code
    self._close_transport()


def _close_transport(self) -> None:
    """Close the transport."""
    req = self._req
    if req is not None and req.transport is not None:
        req.transport.close()


@overload
async def receive(
    self: "WebSocketResponse[Literal[True]]", timeout: float | None = None
) -> WSMessageDecodeText: ...


@overload
async def receive(
    self: "WebSocketResponse[Literal[False]]", timeout: float | None = None
) -> WSMessageNoDecodeText: ...


@overload
async def receive(
    self: "WebSocketResponse[_DecodeText]", timeout: float | None = None
) -> WSMessageDecodeText | WSMessageNoDecodeText: ...


async def receive(
    self, timeout: float | None = None
) -> WSMessageDecodeText | WSMessageNoDecodeText:
    """Return the next message.

    Control frames are handled here: CLOSE/CLOSING mark the connection as
    closing, and PING/PONG are consumed when autoping is enabled.  A falsy
    *timeout* falls back to the instance receive timeout.

    Raises RuntimeError before ``prepare()``, on a concurrent call, and
    after repeated reads from a closed connection.
    """
    if self._reader is None:
        raise RuntimeError("Call .prepare() first")

    effective_timeout = timeout or self._receive_timeout
    while True:
        if self._waiting:
            raise RuntimeError("Concurrent call to receive() is not allowed")

        if self._closed:
            self._conn_lost += 1
            if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
                raise RuntimeError("WebSocket connection is closed.")
            return WS_CLOSED_MESSAGE
        elif self._closing:
            return WS_CLOSING_MESSAGE

        try:
            self._waiting = True
            try:
                if effective_timeout:
                    # Entering the context manager and creating
                    # Timeout() object can take almost 50% of the
                    # run time in this loop so we avoid it if
                    # there is no read timeout.
                    async with async_timeout.timeout(effective_timeout):
                        msg = await self._reader.read()
                else:
                    msg = await self._reader.read()
            finally:
                self._waiting = False
                # Release a close() call waiting for this read to finish.
                if self._close_wait:
                    set_result(self._close_wait, None)
        except asyncio.TimeoutError:
            raise
        except EofStream:
            self._close_code = WSCloseCode.OK
            await self.close()
            return WS_CLOSED_MESSAGE
        except WebSocketError as exc:
            self._close_code = exc.code
            await self.close(code=exc.code)
            return WSMessageError(data=exc)
        except Exception as exc:
            self._exception = exc
            self._set_closing(WSCloseCode.ABNORMAL_CLOSURE)
            await self.close()
            return WSMessageError(data=exc)

        if msg.type not in _INTERNAL_RECEIVE_TYPES:
            # If its not a close/closing/ping/pong message
            # we can return it immediately
            return msg

        if msg.type is WSMsgType.CLOSE:
            self._set_closing(msg.data)
            # Could be closed while awaiting reader.
            if not self._closed and self._autoclose:  # type: ignore[redundant-expr]
                # The client is likely going to close the
                # connection out from under us so we do not
                # want to drain any pending writes as it will
                # likely result writing to a broken pipe.
                await self.close(drain=False)
        elif msg.type is WSMsgType.CLOSING:
            self._set_closing(WSCloseCode.OK)
        elif msg.type is WSMsgType.PING and self._autoping:
            await self.pong(msg.data)
            continue
        elif msg.type is WSMsgType.PONG and self._autoping:
            continue

        return msg


@overload
async def receive_str(
    self: "WebSocketResponse[Literal[True]]", *, timeout: float | None = None
) -> str: ...


@overload
async def receive_str(
    self: "WebSocketResponse[Literal[False]]", *, timeout: float | None = None
) -> bytes: ...


@overload
async def receive_str(
    self: "WebSocketResponse[_DecodeText]", *, timeout: float | None = None
) -> str | bytes: ...


async def receive_str(self, *, timeout: float | None = None) -> str | bytes:
    """Receive TEXT message.

    Returns str when decode_text=True (default), bytes when decode_text=False.
    Raises WSMessageTypeError for any other message type.
    """
    msg = await self.receive(timeout)
    if msg.type is WSMsgType.TEXT:
        return msg.data
    raise WSMessageTypeError(
        f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
    )


async def receive_bytes(self, *, timeout: float | None = None) -> bytes:
    """Receive BINARY message; raise WSMessageTypeError for other types."""
    msg = await self.receive(timeout)
    if msg.type is WSMsgType.BINARY:
        return msg.data
    raise WSMessageTypeError(
        f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
    )


@overload
async def receive_json(
    self: "WebSocketResponse[Literal[True]]",
    *,
    loads: JSONDecoder = ...,
    timeout: float | None = None,
) -> Any: ...


@overload
async def receive_json(
    self: "WebSocketResponse[Literal[False]]",
    *,
    loads: Callable[[bytes], Any] = ...,
    timeout: float | None = None,
) -> Any: ...


@overload
async def receive_json(
    self: "WebSocketResponse[_DecodeText]",
    *,
    loads: JSONDecoder | Callable[[bytes], Any] = ...,
    timeout: float | None = None,
) -> Any: ...


async def receive_json(
    self,
    *,
    loads: JSONDecoder | Callable[[bytes], Any] = json.loads,
    timeout: float | None = None,
) -> Any:
    """Receive a TEXT message and return it parsed by *loads*."""
    raw = await self.receive_str(timeout=timeout)
    return loads(raw)  # type: ignore[arg-type]


async def write(
    self, data: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
) -> None:
    """Always raise RuntimeError: ``write()`` is not usable on a websocket."""
    raise RuntimeError("Cannot call .write() for websocket")


def __aiter__(self) -> Self:
    """Return self as the async iterator."""
    return self


@overload
async def __anext__(
    self: "WebSocketResponse[Literal[True]]",
) -> WSMessageDecodeText: ...


@overload
async def __anext__(
    self: "WebSocketResponse[Literal[False]]",
) -> WSMessageNoDecodeText: ...


@overload
async def __anext__(
    self: "WebSocketResponse[_DecodeText]",
) -> WSMessageDecodeText | WSMessageNoDecodeText: ...


async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText:
    """Return the next message; stop iterating on CLOSE/CLOSING/CLOSED."""
    msg = await self.receive()
    if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
        raise StopAsyncIteration
    return msg
class GunicornWebWorker(base.Worker):  # type: ignore[misc,no-any-unimported]
    """Gunicorn worker that serves an aiohttp application on an asyncio loop."""

    DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
    DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default

    def __init__(self, *args: Any, **kw: Any) -> None:
        super().__init__(*args, **kw)

        self._task: asyncio.Task[None] | None = None
        self.exit_code = 0
        self._notify_waiter: asyncio.Future[bool] | None = None

    def init_process(self) -> None:
        """Create and install a new event loop, then run the base init."""
        # create new event_loop after fork
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        super().init_process()

    def run(self) -> None:
        """Run ``_run()`` to completion, close the loop and exit the process."""
        loop = self.loop
        self._task = loop.create_task(self._run())

        try:  # ignore all finalization problems
            loop.run_until_complete(self._task)
        except Exception:
            self.log.exception("Exception in gunicorn worker")
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()

        sys.exit(self.exit_code)

    async def _run(self) -> None:
        """Start one site per socket and serve until the worker must stop.

        ``self.wsgi`` must be an Application or a coroutine function
        returning an Application or AppRunner; otherwise RuntimeError is
        raised.
        """
        runner = None
        wsgi_entry = self.wsgi
        if isinstance(wsgi_entry, Application):
            app = wsgi_entry
        elif inspect.iscoroutinefunction(wsgi_entry) or (
            sys.version_info < (3, 14) and asyncio.iscoroutinefunction(wsgi_entry)
        ):
            produced = await wsgi_entry()
            if isinstance(produced, web.AppRunner):
                runner = produced
                app = runner.app
            else:
                app = produced
        else:
            raise RuntimeError(
                "wsgi app should be either Application or "
                f"async function returning Application, got {self.wsgi}"
            )

        if runner is None:
            cfg = self.cfg
            access_log = self.log.access_log if cfg.accesslog else None
            runner = web.AppRunner(
                app,
                logger=self.log,
                keepalive_timeout=cfg.keepalive,
                access_log=access_log,
                access_log_format=self._get_valid_log_format(cfg.access_log_format),
                # 95% of gunicorn's graceful timeout.
                shutdown_timeout=cfg.graceful_timeout / 100 * 95,
            )
        # NOTE(review): placed after the `if` so that a caller-supplied
        # runner is set up as well — confirm against the original layout.
        await runner.setup()

        ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None

        assert runner is not None
        server = runner.server
        assert server is not None
        for sock in self.sockets:
            site = web.SockSite(
                runner,
                sock,
                ssl_context=ctx,
            )
            await site.start()

        # If our parent changed then we shut down.
        pid = os.getpid()
        try:
            while self.alive:  # type: ignore[has-type]
                self.notify()

                served = server.requests_count
                if self.max_requests and served > self.max_requests:
                    self.alive = False
                    self.log.info("Max requests, shutting down: %s", self)
                elif pid == os.getpid() and self.ppid != os.getppid():
                    self.alive = False
                    self.log.info("Parent changed, shutting down: %s", self)
                else:
                    await self._wait_next_notify()
        except Exception:
            # Errors from the serve loop are swallowed; cleanup still runs.
            pass

        await runner.cleanup()

    def _wait_next_notify(self) -> "asyncio.Future[bool]":
        """Return a future resolved after one second or by an earlier wake-up."""
        self._notify_waiter_done()

        loop = self.loop
        assert loop is not None
        waiter = loop.create_future()
        self._notify_waiter = waiter
        loop.call_later(1.0, self._notify_waiter_done, waiter)

        return waiter

    def _notify_waiter_done(
        self, waiter: Optional["asyncio.Future[bool]"] = None
    ) -> None:
        """Resolve *waiter* (default: the current one) and forget it."""
        target = self._notify_waiter if waiter is None else waiter
        if target is not None:
            set_result(target, True)

        if target is self._notify_waiter:
            self._notify_waiter = None

    def init_signals(self) -> None:
        """Register the worker's signal handlers on the event loop."""
        # Set up signals through the event loop API.
        routing = (
            (signal.SIGQUIT, self.handle_quit),
            (signal.SIGTERM, self.handle_exit),
            (signal.SIGINT, self.handle_quit),
            (signal.SIGWINCH, self.handle_winch),
            (signal.SIGUSR1, self.handle_usr1),
            (signal.SIGABRT, self.handle_abort),
        )
        for signum, handler in routing:
            self.loop.add_signal_handler(signum, handler, signum, None)

        # Don't let SIGTERM and SIGUSR1 disturb active requests
        # by interrupting system calls
        for signum in (signal.SIGTERM, signal.SIGUSR1):
            signal.siginterrupt(signum, False)

        # NOTE(review): the previous comment here spoke of resetting signals
        # so Gunicorn doesn't swallow subprocess return codes
        # (https://github.com/aio-libs/aiohttp/issues/6130), but no statement
        # follows it — confirm whether a reset is still required.

    def handle_quit(self, sig: int, frame: FrameType | None) -> None:
        """Stop the serve loop, call ``worker_int`` and wake the waiter."""
        self.alive = False

        # worker_int callback
        self.cfg.worker_int(self)

        # wakeup closing process
        self._notify_waiter_done()

    def handle_abort(self, sig: int, frame: FrameType | None) -> None:
        """Call ``worker_abort`` and exit the process with status 1."""
        self.alive = False
        self.exit_code = 1
        self.cfg.worker_abort(self)
        sys.exit(1)

    @staticmethod
    def _create_ssl_context(cfg: Any) -> "SSLContext":
        """Creates SSLContext instance for usage in asyncio.create_server.

        See ssl.SSLSocket.__init__ for more details.
        """
        if ssl is None:  # pragma: no cover
            raise RuntimeError("SSL is not supported.")

        context = ssl.SSLContext(cfg.ssl_version)
        context.load_cert_chain(cfg.certfile, cfg.keyfile)
        context.verify_mode = cfg.cert_reqs
        if cfg.ca_certs:
            context.load_verify_locations(cfg.ca_certs)
        if cfg.ciphers:
            context.set_ciphers(cfg.ciphers)
        return context

    def _get_valid_log_format(self, source_format: str) -> str:
        """Return the access-log format to hand to aiohttp.

        Gunicorn's default format maps to aiohttp's default.  A format with
        ``%(name)s`` style options raises ValueError.  Anything else is
        returned unchanged.
        """
        if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
            return self.DEFAULT_AIOHTTP_LOG_FORMAT
        if re.search(r"%\([^\)]+\)", source_format):
            raise ValueError(
                "Gunicorn's style options in form of `%(name)s` are not "
                "supported for the log formatting. Please use aiohttp's "
                "format specification to configure access log formatting: "
                "http://docs.aiohttp.org/en/stable/logging.html"
                "#format-specification"
            )
        return source_format


class GunicornUVLoopWebWorker(GunicornWebWorker):
    """GunicornWebWorker variant that installs the uvloop policy first."""

    def init_process(self) -> None:
        import uvloop

        # Setup uvloop policy, so that every
        # asyncio.get_event_loop() will create an instance
        # of uvloop event loop.
        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

        super().init_process()
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext help: @echo "Please use \`make ' where is one of" @echo " html to make standalone HTML files" @echo " dirhtml to make HTML files named index.html in directories" @echo " singlehtml to make a single large HTML file" @echo " pickle to make pickle files" @echo " json to make JSON files" @echo " htmlhelp to make HTML files and a HTML help project" @echo " qthelp to make HTML files and a qthelp project" @echo " devhelp to make HTML files and a Devhelp project" @echo " epub to make an epub" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" @echo " text to make text files" @echo " man to make manual pages" @echo " texinfo to make Texinfo files" @echo " info to make Texinfo files and run them through makeinfo" @echo " gettext to make PO message catalogs" @echo " changes to make an overview of all changed/added/deprecated items" @echo " xml to make Docutils-native XML files" @echo " pseudoxml to make pseudoxml-XML files for display purposes" @echo " linkcheck to check all external links for integrity" @echo " doctest to run all doctests embedded in the documentation (if enabled)" clean: rm -rf $(BUILDDIR)/* html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." dirhtml: $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." singlehtml: $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml @echo @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." pickle: $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle @echo @echo "Build finished; now you can process the pickle files." 
json: $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json @echo @echo "Build finished; now you can process the JSON files." htmlhelp: $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp @echo @echo "Build finished; now you can run HTML Help Workshop with the" \ ".hhp project file in $(BUILDDIR)/htmlhelp." qthelp: $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/aiohttp.qhcp" @echo "To view the help file:" @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/aiohttp.qhc" devhelp: $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp @echo @echo "Build finished." @echo "To view the help file:" @echo "# mkdir -p $$HOME/.local/share/devhelp/aiohttp" @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/aiohttp" @echo "# devhelp" epub: $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub @echo @echo "Build finished. The epub file is in $(BUILDDIR)/epub." latex: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Run \`make' in that directory to run these through (pdf)latex" \ "(use \`make latexpdf' here to do that automatically)." latexpdf: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." latexpdfja: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through platex and dvipdfmx..." $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." text: $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text @echo @echo "Build finished. The text files are in $(BUILDDIR)/text." 
man: $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man @echo @echo "Build finished. The manual pages are in $(BUILDDIR)/man." texinfo: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." @echo "Run \`make' in that directory to run these through makeinfo" \ "(use \`make info' here to do that automatically)." info: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo "Running Texinfo files through makeinfo..." make -C $(BUILDDIR)/texinfo info @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." gettext: $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale @echo @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." changes: $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes @echo @echo "The overview file is in $(BUILDDIR)/changes." linkcheck: $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck @echo @echo "Link check complete; look for any errors in the above output " \ "or in $(BUILDDIR)/linkcheck/output.txt." doctest: $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." xml: $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml @echo @echo "Build finished. The XML files are in $(BUILDDIR)/xml." pseudoxml: $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml @echo @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." spelling: $(SPHINXBUILD) -b spelling $(ALLSPHINXOPTS) $(BUILDDIR)/spelling @echo @echo "Build finished." 
================================================ FILE: docs/_static/css/logo-adjustments.css ================================================ .sphinxsidebarwrapper>h1.logo { display: none; } .sphinxsidebarwrapper>p.logo>a>img.logo { width: 65%; } ================================================ FILE: docs/abc.rst ================================================ .. module:: aiohttp.abc .. _aiohttp-abc: Abstract Base Classes ===================== Abstract routing ---------------- aiohttp has abstract classes for managing web interfaces. The most part of :mod:`aiohttp.web` is not intended to be inherited but few of them are. aiohttp.web is built on top of few concepts: *application*, *router*, *request* and *response*. *router* is a *pluggable* part: a library user may build a *router* from scratch, all other parts should work with new router seamlessly. :class:`aiohttp.abc.AbstractRouter` has the only mandatory method: :meth:`aiohttp.abc.AbstractRouter.resolve` coroutine. It must return an :class:`aiohttp.abc.AbstractMatchInfo` instance. If the requested URL handler is found :meth:`aiohttp.abc.AbstractMatchInfo.handler` is a :term:`web-handler` for requested URL and :attr:`aiohttp.abc.AbstractMatchInfo.http_exception` is ``None``. Otherwise :attr:`aiohttp.abc.AbstractMatchInfo.http_exception` is an instance of :exc:`~aiohttp.web.HTTPException` like *404: NotFound* or *405: Method Not Allowed*. :meth:`aiohttp.abc.AbstractMatchInfo.handler` raises :attr:`~aiohttp.abc.AbstractMatchInfo.http_exception` on call. .. class:: AbstractRouter Abstract router, :class:`aiohttp.web.Application` accepts it as *router* parameter and returns as :attr:`aiohttp.web.Application.router`. .. method:: resolve(request) :async: Performs URL resolving. It's an abstract method, should be overridden in *router* implementation. :param request: :class:`aiohttp.web.Request` instance for resolving, the request has :attr:`aiohttp.web.Request.match_info` equals to ``None`` at resolving stage. 
:return: :class:`aiohttp.abc.AbstractMatchInfo` instance. .. class:: AbstractMatchInfo Abstract *match info*, returned by :meth:`aiohttp.abc.AbstractRouter.resolve` call. .. attribute:: http_exception :exc:`aiohttp.web.HTTPException` if no match was found, ``None`` otherwise. .. method:: handler(request) :async: Abstract method performing :term:`web-handler` processing. :param request: :class:`aiohttp.web.Request` instance for resolving, the request has :attr:`aiohttp.web.Request.match_info` equal to ``None`` at resolving stage. :return: :class:`aiohttp.web.StreamResponse` or descendants. :raise: :class:`aiohttp.web.HTTPException` on error .. method:: expect_handler(request) :async: Abstract method for handling *100-continue* processing. Abstract Class Based Views -------------------------- For *class based view* support aiohttp has abstract :class:`AbstractView` class which is *awaitable* (may be used like ``await Cls()`` or ``yield from Cls()``) and has a *request* as an attribute. .. class:: AbstractView An abstract class, base for all *class based views* implementations. Methods ``__iter__`` and ``__await__`` should be overridden. .. attribute:: request :class:`aiohttp.web.Request` instance for performing the request. Abstract Cookie Jar ------------------- .. class:: AbstractCookieJar The cookie jar instance is available as :attr:`aiohttp.ClientSession.cookie_jar`. The jar contains :class:`~http.cookies.Morsel` items for storing internal cookie data. API provides a count of saved cookies:: len(session.cookie_jar) These cookies may be iterated over:: for cookie in session.cookie_jar: print(cookie.key) print(cookie["domain"]) An abstract class for cookie storage. Implements :class:`collections.abc.Iterable` and :class:`collections.abc.Sized`. .. method:: update_cookies(cookies, response_url=None) Update cookies returned by server in ``Set-Cookie`` header. :param cookies: a :class:`collections.abc.Mapping` (e.g.
:class:`dict`, :class:`~http.cookies.SimpleCookie`) or *iterable* of *pairs* with cookies returned by server's response. :param str response_url: URL of response, ``None`` for *shared cookies*. Regular cookies are coupled with server's URL and are sent only to this server, shared ones are sent in every client request. .. method:: filter_cookies(request_url) Return jar's cookies acceptable for URL and available in ``Cookie`` header for sending client requests for given URL. :param str request_url: request's URL for which cookies are asked. :return: :class:`http.cookies.SimpleCookie` with filtered cookies for given URL. .. method:: clear(predicate=None) Removes all cookies from the jar if the predicate is ``None``. Otherwise remove only those :class:`~http.cookies.Morsel` items for which ``predicate(morsel)`` returns ``True``. :param predicate: callable that gets :class:`~http.cookies.Morsel` as a parameter and returns ``True`` if this :class:`~http.cookies.Morsel` must be deleted from the jar. .. versionadded:: 3.8 .. method:: clear_domain(domain) Remove all cookies from the jar that belong to the specified domain or its subdomains. :param str domain: domain for which cookies must be deleted from the jar. .. versionadded:: 3.8 Abstract Access Logger ------------------------------- .. class:: AbstractAccessLogger An abstract class, base for all :class:`aiohttp.web.RequestHandler` ``access_logger`` implementations. Method ``log`` should be overridden. .. method:: log(request, response, time) :param request: :class:`aiohttp.web.Request` object. :param response: :class:`aiohttp.web.Response` object. :param float time: Time taken to serve the request. .. attribute:: enabled Return True if logger is enabled. Override this property if logging is disabled to avoid the overhead of calculating details to feed the logger. This property may be omitted if logging is always enabled. Abstract Resolver ------------------------------- ..
class:: AbstractResolver An abstract class, base for all resolver implementations. Method ``resolve`` should be overridden. .. method:: resolve(host, port, family) Resolve host name to IP address. :param str host: host name to resolve. :param int port: port number. :param int family: socket family. :return: list of :class:`aiohttp.abc.ResolveResult` instances. .. method:: close() Release resolver. .. class:: ResolveResult Result of host name resolution. .. attribute:: hostname The host name that was provided. .. attribute:: host The IP address that was resolved. .. attribute:: port The port that was resolved. .. attribute:: family The address family that was resolved. .. attribute:: proto The protocol that was resolved. .. attribute:: flags The flags that were resolved. ================================================ FILE: docs/built_with.rst ================================================ .. _aiohttp-built-with: Built with aiohttp ================== aiohttp is used to build useful libraries built on top of it, and there's a page dedicated to list them: :ref:`aiohttp-3rd-party`. There are also projects that leverage the power of aiohttp to provide end-user tools, like command lines or software with full user interfaces. This page aims to list those projects. If you are using aiohttp in your software and if it's playing a central role, you can add it here in this list. You can also add a **Built with aiohttp** link somewhere in your project, pointing to ``_. * `Pulp `_ Platform for managing repositories of software packages and making them available to consumers. * `repo-peek `_ CLI tool to open a remote repo locally quickly. * `Molotov `_ Load testing tool. * `Arsenic `_ Async WebDriver. * `Home Assistant `_ Home Automation Platform. * `Backend.AI `_ Code execution API service. * `doh-proxy `_ DNS Over HTTPS Proxy. * `Mariner `_ Command-line torrent searcher. * `DEEPaaS API `_ REST API for Machine learning, Deep learning and artificial intelligence applications. 
* `BentoML `_ Machine Learning model serving framework * `salted `_ fast link check library (for HTML, Markdown, LaTeX, ...) with CLI * `Unofficial Tabdeal API `_ A package to communicate with the *Tabdeal* trading platform. ================================================ FILE: docs/changes.rst ================================================ .. _aiohttp_changes: ========= Changelog ========= .. only:: not is_release To be included in v\ |release| (if present) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. towncrier-draft-entries:: |release| [UNRELEASED DRAFT] Released versions ^^^^^^^^^^^^^^^^^ .. include:: ../CHANGES.rst :start-after: .. towncrier release notes start ================================================ FILE: docs/client.rst ================================================ .. _aiohttp-client: Client ====== .. currentmodule:: aiohttp The page contains all information about aiohttp Client API: .. toctree:: :name: client :maxdepth: 3 Quickstart Advanced Usage Client Middleware Cookbook Reference Tracing Reference The aiohttp Request Lifecycle ================================================ FILE: docs/client_advanced.rst ================================================ .. currentmodule:: aiohttp .. _aiohttp-client-advanced: Advanced Client Usage ===================== .. _aiohttp-client-session: Client Session -------------- :class:`ClientSession` is the heart and the main entry point for all client API operations. Create the session first, use the instance for performing HTTP requests and initiating WebSocket connections. The session contains a cookie storage and connection pool, thus cookies and connections are shared between HTTP requests sent by the same session. Custom Request Headers ---------------------- If you need to add HTTP headers to a request, pass them in a :class:`dict` to the *headers* parameter. 
For example, if you want to specify the content-type directly:: url = 'http://example.com/image' payload = b'GIF89a\x01\x00\x01\x00\x00\xff\x00,\x00\x00' b'\x00\x00\x01\x00\x01\x00\x00\x02\x00;' headers = {'content-type': 'image/gif'} await session.post(url, data=payload, headers=headers) You also can set default headers for all session requests:: headers={"Authorization": "Basic bG9naW46cGFzcw=="} async with aiohttp.ClientSession(headers=headers) as session: async with session.get("http://httpbin.org/headers") as r: json_body = await r.json() assert json_body['headers']['Authorization'] == \ 'Basic bG9naW46cGFzcw==' Typical use case is sending JSON body. You can specify content type directly as shown above, but it is more convenient to use special keyword ``json``:: await session.post(url, json={'example': 'text'}) For ``text/plain``:: await session.post(url, data='Привет, Мир!') Authentication -------------- Instead of setting the ``Authorization`` header directly, :class:`ClientSession` and individual request methods provide an ``auth`` argument. An instance of :class:`BasicAuth` can be passed in like this:: auth = BasicAuth(login="...", password="...") async with ClientSession(auth=auth) as session: ... For HTTP digest authentication, use the :class:`DigestAuthMiddleware` client middleware:: from aiohttp import ClientSession, DigestAuthMiddleware # Create the middleware with your credentials digest_auth = DigestAuthMiddleware(login="user", password="password") # Pass it to the ClientSession as a tuple async with ClientSession(middlewares=(digest_auth,)) as session: # The middleware will automatically handle auth challenges async with session.get("https://example.com/protected") as resp: print(await resp.text()) The :class:`DigestAuthMiddleware` implements HTTP Digest Authentication according to RFC 7616, providing a more secure alternative to Basic Authentication. 
It supports all standard hash algorithms including MD5, SHA, SHA-256, SHA-512 and their session variants, as well as both 'auth' and 'auth-int' quality of protection (qop) options. The middleware automatically handles the authentication flow by intercepting 401 responses and retrying with proper credentials. Note that if the request is redirected and the redirect URL contains credentials, those credentials will supersede any previously set credentials. In other words, if ``http://user@example.com`` redirects to ``http://other_user@example.com``, the second request will be authenticated as ``other_user``. Providing both the ``auth`` parameter and authentication in the *initial* URL will result in a :exc:`ValueError`. For other authentication flows, the ``Authorization`` header can be set directly:: headers = {"Authorization": "Bearer eyJh...0M30"} async with ClientSession(headers=headers) as session: ... The authentication header for a session may be updated as and when required. For example:: session.headers["Authorization"] = "Bearer eyJh...1OH0" Note that a *copy* of the headers dictionary is set as an attribute when creating a :class:`ClientSession` instance (as a :class:`multidict.CIMultiDict` object). Updating the original dictionary does not have any effect. In cases where the authentication header value expires periodically, an :mod:`asyncio` task may be used to update the session's default headers in the background. .. note:: The ``Authorization`` header will be removed if you get redirected to a different host or protocol, except the case when HTTP → HTTPS redirect is performed on the same host. .. versionchanged:: 4.0 Started keeping the ``Authorization`` header during HTTP → HTTPS redirects when the host remains the same. .. _aiohttp-client-middleware: Client Middleware ----------------- The client supports middleware to intercept requests and responses. This can be useful for authentication, logging, request/response modification, retries etc. 
For more examples and common middleware patterns, see the :ref:`aiohttp-client-middleware-cookbook`. Creating a middleware ^^^^^^^^^^^^^^^^^^^^^ To create a middleware, define an async function (or callable class) that accepts a request object and a handler function, and returns a response. Middlewares must follow the :type:`ClientMiddlewareType` signature:: async def auth_middleware(req: ClientRequest, handler: ClientHandlerType) -> ClientResponse: req.headers["Authorization"] = get_auth_header() return await handler(req) Using Middlewares ^^^^^^^^^^^^^^^^^ You can apply middlewares to a client session or to individual requests:: # Apply to all requests in a session async with ClientSession(middlewares=(my_middleware,)) as session: resp = await session.get("http://example.com") # Apply to a specific request async with ClientSession() as session: resp = await session.get("http://example.com", middlewares=(my_middleware,)) Middleware Chaining ^^^^^^^^^^^^^^^^^^^ Multiple middlewares are applied in the order they are listed:: # Middlewares are applied in order: logging -> auth -> request async with ClientSession(middlewares=(logging_middleware, auth_middleware)) as session: async with session.get("http://example.com") as resp: ... A key aspect to understand about the middleware sequence is that the execution flow follows this pattern: 1. The first middleware in the list is called first and executes its code before calling the handler 2. The handler is the next middleware in the chain (or the request handler if there are no more middlewares) 3. When the handler returns a response, execution continues from the last middleware right after the handler call 4. This creates a nested "onion-like" pattern for execution For example, with ``middlewares=(middleware1, middleware2)``, the execution order would be: 1. Enter ``middleware1`` (pre-request code) 2. Enter ``middleware2`` (pre-request code) 3. Execute the actual request handler 4. 
Exit ``middleware2`` (post-response code) 5. Exit ``middleware1`` (post-response code) This flat structure means that a middleware is applied on each retry attempt inside the client's retry loop, not just once before all retries. This allows middleware to modify requests freshly on each retry attempt. For example, if we had a retry middleware and a logging middleware, and we want every retried request to be logged separately, then we'd need to specify ``middlewares=(retry_mw, logging_mw)``. If we reversed the order to ``middlewares=(logging_mw, retry_mw)``, then we'd only log once regardless of how many retries are done. .. note:: Client middleware is a powerful feature but should be used judiciously. Each middleware adds overhead to request processing. For simple use cases like adding static headers, you can often use request parameters (e.g., ``headers``) or session configuration instead. Custom Cookies -------------- To send your own cookies to the server, you can use the *cookies* parameter of :class:`ClientSession` constructor:: url = 'http://httpbin.org/cookies' cookies = {'cookies_are': 'working'} async with ClientSession(cookies=cookies) as session: async with session.get(url) as resp: assert await resp.json() == { "cookies": {"cookies_are": "working"}} .. note:: ``httpbin.org/cookies`` endpoint returns request cookies in JSON-encoded body. To access session cookies see :attr:`ClientSession.cookie_jar`. 
:class:`~aiohttp.ClientSession` may be used for sharing cookies between multiple requests:: async with aiohttp.ClientSession() as session: async with session.get( "http://httpbin.org/cookies/set?my_cookie=my_value", allow_redirects=False ) as resp: assert resp.cookies["my_cookie"].value == "my_value" async with session.get("http://httpbin.org/cookies") as r: json_body = await r.json() assert json_body["cookies"]["my_cookie"] == "my_value" Response Headers and Cookies ---------------------------- We can view the server's response :attr:`ClientResponse.headers` using a :class:`~multidict.CIMultiDictProxy`:: assert resp.headers == { 'ACCESS-CONTROL-ALLOW-ORIGIN': '*', 'CONTENT-TYPE': 'application/json', 'DATE': 'Tue, 15 Jul 2014 16:49:51 GMT', 'SERVER': 'gunicorn/18.0', 'CONTENT-LENGTH': '331', 'CONNECTION': 'keep-alive'} The dictionary is special, though: it's made just for HTTP headers. According to `RFC 7230 `_, HTTP Header names are case-insensitive. It also supports multiple values for the same key as HTTP protocol does. So, we can access the headers using any capitalization we want:: assert resp.headers['Content-Type'] == 'application/json' assert resp.headers.get('content-type') == 'application/json' All headers are converted from binary data using UTF-8 with ``surrogateescape`` option. That works fine on most cases but sometimes unconverted data is needed if a server uses nonstandard encoding. While these headers are malformed from :rfc:`7230` perspective they may be retrieved by using :attr:`ClientResponse.raw_headers` property:: assert resp.raw_headers == ( (b'SERVER', b'nginx'), (b'DATE', b'Sat, 09 Jan 2016 20:28:40 GMT'), (b'CONTENT-TYPE', b'text/html; charset=utf-8'), (b'CONTENT-LENGTH', b'12150'), (b'CONNECTION', b'keep-alive')) If a response contains some *HTTP Cookies*, you can quickly access them:: url = 'http://example.com/some/cookie/setting/url' async with session.get(url) as resp: print(resp.cookies['example_cookie_name']) .. 
note:: Response cookies contain only the values that were in ``Set-Cookie`` headers of the **last** request in the redirection chain. To gather cookies between all redirection requests please use :ref:`aiohttp.ClientSession <aiohttp-client-session>` object. Redirection History ------------------- If a request was redirected, it is possible to view previous responses using the :attr:`~ClientResponse.history` attribute:: resp = await session.get('http://example.com/some/redirect/') assert resp.status == 200 assert resp.url == URL('http://example.com/some/other/url/') assert len(resp.history) == 1 assert resp.history[0].status == 301 assert resp.history[0].url == URL( 'http://example.com/some/redirect/') If no redirects occurred or ``allow_redirects`` is set to ``False``, history will be an empty sequence. Cookie Jar ---------- .. _aiohttp-client-cookie-safety: Cookie Safety ^^^^^^^^^^^^^ By default :class:`~aiohttp.ClientSession` uses a strict version of :class:`aiohttp.CookieJar`. :rfc:`2109` explicitly forbids accepting cookies from URLs with an IP address instead of a DNS name (e.g. ``http://127.0.0.1:80/cookie``). It's good but sometimes for testing we need to enable support for such cookies. It should be done by passing ``unsafe=True`` to :class:`aiohttp.CookieJar` constructor:: jar = aiohttp.CookieJar(unsafe=True) session = aiohttp.ClientSession(cookie_jar=jar) .. _aiohttp-client-cookie-quoting-routine: Cookie Quoting Routine ^^^^^^^^^^^^^^^^^^^^^^ The client uses the :class:`~aiohttp.SimpleCookie` quoting routines, which conform to :rfc:`2109`, which in turn references the character definitions from :rfc:`2068`. They provide a two-way quoting algorithm where any non-text character is translated into a 4 character sequence: a backslash followed by the three-digit octal equivalent of the character. Any ``\`` or ``"`` is quoted with a preceding ``\`` (backslash). Because of the way browsers really handle cookies (as opposed to what the RFC says) we also encode ``,`` and ``;``.
Some backend systems do not support quoted cookies. You can skip this quotation routine by passing ``quote_cookie=False`` to the :class:`~aiohttp.CookieJar` constructor:: jar = aiohttp.CookieJar(quote_cookie=False) session = aiohttp.ClientSession(cookie_jar=jar) .. _aiohttp-client-dummy-cookie-jar: Dummy Cookie Jar ^^^^^^^^^^^^^^^^ Sometimes cookie processing is not desirable. For this purpose it's possible to pass an :class:`aiohttp.DummyCookieJar` instance into the client session:: jar = aiohttp.DummyCookieJar() session = aiohttp.ClientSession(cookie_jar=jar) Uploading pre-compressed data ----------------------------- To upload data that is already compressed before passing it to aiohttp, call the request function with the used compression algorithm name (usually ``deflate`` or ``gzip``) as the value of the ``Content-Encoding`` header:: async def my_coroutine(session, headers, my_data): data = zlib.compress(my_data) headers = {'Content-Encoding': 'deflate'} async with session.post('http://httpbin.org/post', data=data, headers=headers): pass Disabling content type validation for JSON responses ---------------------------------------------------- The standard explicitly restricts JSON ``Content-Type`` HTTP header to ``application/json`` or any extended form, e.g. ``application/vnd.custom-type+json``. Unfortunately, some servers send a wrong type, like ``text/html``. This can be worked around in two ways: 1. Pass the expected type explicitly (in this case checking will be strict, without the extended form support, so ``custom/xxx+type`` won't be accepted): ``await resp.json(content_type='custom/type')``. 2. Disable the check entirely: ``await resp.json(content_type=None)``. ..
_aiohttp-client-tracing: Client Tracing -------------- The execution flow of a specific request can be followed by attaching listener coroutines to the signals provided by the :class:`TraceConfig` instance. This instance is passed as a parameter to the :class:`ClientSession` constructor, resulting in a client that triggers the different signals supported by the :class:`TraceConfig`. By default any instance of :class:`ClientSession` class comes with the signals ability disabled. The following snippet shows how the start and the end signals of a request flow can be followed:: async def on_request_start( session, trace_config_ctx, params): print("Starting request") async def on_request_end(session, trace_config_ctx, params): print("Ending request") trace_config = aiohttp.TraceConfig() trace_config.on_request_start.append(on_request_start) trace_config.on_request_end.append(on_request_end) async with aiohttp.ClientSession( trace_configs=[trace_config]) as client: await client.get('http://example.com/some/redirect/') The ``trace_configs`` is a list that can contain instances of :class:`TraceConfig` class that allow running the signal handlers coming from different :class:`TraceConfig` instances. The following example shows how two different :class:`TraceConfig` that have a different nature are installed to perform their job in each signal handle:: from mylib.traceconfig import AuditRequest from mylib.traceconfig import XRay async with aiohttp.ClientSession( trace_configs=[AuditRequest(), XRay()]) as client: await client.get('http://example.com/some/redirect/') All signals take as parameters, first, the :class:`ClientSession` instance used by the specific request related to those signals and, second, a :class:`~types.SimpleNamespace` instance called ``trace_config_ctx``.
The ``trace_config_ctx`` object can be used to share state between the different signals that belong to the same request and to the same :class:`TraceConfig` class, for example:: async def on_request_start( session, trace_config_ctx, params): trace_config_ctx.start = asyncio.get_event_loop().time() async def on_request_end(session, trace_config_ctx, params): elapsed = asyncio.get_event_loop().time() - trace_config_ctx.start print("Request took {}".format(elapsed)) The ``trace_config_ctx`` param is by default a :class:`~types.SimpleNamespace` that is initialized at the beginning of the request flow. However, the factory used to create this object can be overridden using the ``trace_config_ctx_factory`` constructor param of the :class:`TraceConfig` class. The ``trace_request_ctx`` param can be given at the beginning of the request execution, accepted by all of the HTTP verbs, and will be passed as a keyword argument for the ``trace_config_ctx_factory`` factory. This param is useful to pass data that is only available at request time, for example:: async def on_request_start( session, trace_config_ctx, params): print(trace_config_ctx.trace_request_ctx) session.get('http://example.com/some/redirect/', trace_request_ctx={'foo': 'bar'}) .. seealso:: :ref:`aiohttp-client-tracing-reference` section for more information about the different signals supported. Connectors ---------- To tweak or change *transport* layer of requests you can pass a custom *connector* to :class:`~aiohttp.ClientSession` and family. For example:: conn = aiohttp.TCPConnector() session = aiohttp.ClientSession(connector=conn) .. note:: By default *session* object takes the ownership of the connector, among other things closing the connections once the *session* is closed. If you are keen on sharing the same *connector* through different *session* instances you must give the *connector_owner* parameter as **False** for each *session* instance. ..
seealso:: :ref:`aiohttp-client-reference-connectors` section for more information about different connector types and configuration options. Limiting connection pool size ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ To limit the number of simultaneously opened connections you can pass the *limit* parameter to *connector*:: conn = aiohttp.TCPConnector(limit=30) The example limits the total number of parallel connections to `30`. The default is `100`. If you explicitly want not to have limits, pass `0`. For example:: conn = aiohttp.TCPConnector(limit=0) To limit the number of simultaneously opened connections to the same endpoint (``(host, port, is_ssl)`` triple) you can pass the *limit_per_host* parameter to *connector*:: conn = aiohttp.TCPConnector(limit_per_host=30) The example limits the number of parallel connections to the same endpoint to `30`. The default is `0` (no limit on a per-host basis). Tuning the DNS cache ^^^^^^^^^^^^^^^^^^^^ By default :class:`~aiohttp.TCPConnector` comes with the DNS cache table enabled, and resolutions will be cached by default for `10` seconds.
This behavior can be changed either by changing the TTL for a resolution, as can be seen in the following example:: conn = aiohttp.TCPConnector(ttl_dns_cache=300) or by disabling the use of the DNS cache table, meaning that all requests will end up making a DNS resolution, as the following example shows:: conn = aiohttp.TCPConnector(use_dns_cache=False) Resolving using custom nameservers ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ In order to specify the nameservers to use when resolving the hostnames, :term:`aiodns` is required:: from aiohttp.resolver import AsyncResolver resolver = AsyncResolver(nameservers=["8.8.8.8", "8.8.4.4"]) conn = aiohttp.TCPConnector(resolver=resolver) Unix domain sockets ^^^^^^^^^^^^^^^^^^^ If your HTTP server uses UNIX domain sockets you can use :class:`~aiohttp.UnixConnector`:: conn = aiohttp.UnixConnector(path='/path/to/socket') session = aiohttp.ClientSession(connector=conn) Custom socket creation ^^^^^^^^^^^^^^^^^^^^^^ If the default socket is insufficient for your use case, pass an optional ``socket_factory`` to the :class:`~aiohttp.TCPConnector`, which implements :class:`SocketFactoryType`. This will be used to create all sockets for the lifetime of the class object. For example, we may want to change the conditions under which we consider a connection dead.
The following would make all sockets respect 9*7200 = 18 hours:: import socket def socket_factory(addr_info): family, type_, proto, _, _ = addr_info sock = socket.socket(family=family, type=type_, proto=proto) sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) return sock conn = aiohttp.TCPConnector(socket_factory=socket_factory) ``socket_factory`` may also be used for binding to a specific network interface on supported platforms:: def socket_factory(addr_info): family, type_, proto, _, _ = addr_info sock = socket.socket(family=family, type=type_, proto=proto) sock.setsockopt( socket.SOL_SOCKET, socket.SO_BINDTODEVICE, b'eth0' ) return sock conn = aiohttp.TCPConnector(socket_factory=socket_factory) Named pipes in Windows ^^^^^^^^^^^^^^^^^^^^^^ If your HTTP server uses Named pipes you can use :class:`~aiohttp.NamedPipeConnector`:: conn = aiohttp.NamedPipeConnector(path=r'\\.\pipe\<name-of-pipe>') session = aiohttp.ClientSession(connector=conn) It will only work with the ProactorEventLoop. SSL control for TCP sockets --------------------------- By default *aiohttp* uses strict checks for HTTPS protocol. Certificate checks can be relaxed by setting *ssl* to ``False``:: r = await session.get('https://example.com', ssl=False) If you need to set up custom SSL parameters (for example, to use your own certificate files) you can create a :class:`ssl.SSLContext` instance and pass it into the :meth:`ClientSession.request` methods or set it for the entire session with ``ClientSession(connector=TCPConnector(ssl=ssl_context))``.
There are explicit errors when SSL verification fails: :class:`aiohttp.ClientConnectorSSLError`:: try: await session.get('https://expired.badssl.com/') except aiohttp.ClientConnectorSSLError as e: assert isinstance(e, ssl.SSLError) :class:`aiohttp.ClientConnectorCertificateError`:: try: await session.get('https://wrong.host.badssl.com/') except aiohttp.ClientConnectorCertificateError as e: assert isinstance(e, ssl.CertificateError) If you need to catch both SSL-related errors, use :class:`aiohttp.ClientSSLError`:: try: await session.get('https://expired.badssl.com/') except aiohttp.ClientSSLError as e: assert isinstance(e, ssl.SSLError) try: await session.get('https://wrong.host.badssl.com/') except aiohttp.ClientSSLError as e: assert isinstance(e, ssl.CertificateError) Example: Use certifi ^^^^^^^^^^^^^^^^^^^^ By default, Python uses the system CA certificates. In rare cases, these may not be installed or Python is unable to find them, resulting in an error like `ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate` One way to work around this problem is to use the `certifi` package:: ssl_context = ssl.create_default_context(cafile=certifi.where()) async with ClientSession(connector=TCPConnector(ssl=ssl_context)) as sess: ... Example: Use self-signed certificate ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ If you need to verify *self-signed* certificates, you need to add a call to :meth:`ssl.SSLContext.load_cert_chain` with the key pair:: ssl_context = ssl.create_default_context() ssl_context.load_cert_chain("/path/to/client/public/device.pem", "/path/to/client/private/device.key") async with sess.get("https://example.com", ssl=ssl_context) as resp: ...
Example: Verify certificate fingerprint ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ You may also verify certificates via *SHA256* fingerprint:: # Attempt to connect to https://www.python.org # with a pin to a bogus certificate: bad_fp = b'0'*64 exc = None try: r = await session.get('https://www.python.org', ssl=aiohttp.Fingerprint(bad_fp)) except aiohttp.FingerprintMismatch as e: exc = e assert exc is not None assert exc.expected == bad_fp # www.python.org cert's actual fingerprint assert exc.got == b'...' Note that this is the fingerprint of the DER-encoded certificate. If you have the certificate in PEM format, you can convert it to DER with e.g:: openssl x509 -in crt.pem -inform PEM -outform DER > crt.der .. note:: Tip: to convert from a hexadecimal digest to a binary byte-string, you can use :func:`binascii.unhexlify`. *ssl* parameter could be passed to :class:`TCPConnector` as default, the value from :meth:`ClientSession.get` and others override default. .. _aiohttp-client-proxy-support: Proxy support ------------- aiohttp supports plain HTTP proxies and HTTP proxies that can be upgraded to HTTPS via the HTTP CONNECT method. aiohttp has a limited support for proxies that must be connected to via ``https://`` — see the info box below for more details. 
To connect, use the *proxy* parameter:: async with aiohttp.ClientSession() as session: async with session.get("http://python.org", proxy="http://proxy.com") as resp: print(resp.status) It also supports proxy authorization:: async with aiohttp.ClientSession() as session: proxy_auth = aiohttp.BasicAuth('user', 'pass') async with session.get("http://python.org", proxy="http://proxy.com", proxy_auth=proxy_auth) as resp: print(resp.status) Authentication credentials can be passed in the proxy URL:: session.get("http://python.org", proxy="http://user:pass@some.proxy.com") And you may set a default proxy:: proxy_auth = aiohttp.BasicAuth('user', 'pass') async with aiohttp.ClientSession(proxy="http://proxy.com", proxy_auth=proxy_auth) as session: async with session.get("http://python.org") as resp: print(resp.status) Contrary to the ``requests`` library, it won't read environment variables by default. But you can do so by passing ``trust_env=True`` into the :class:`aiohttp.ClientSession` constructor:: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get("http://python.org") as resp: print(resp.status) .. note:: aiohttp uses :func:`urllib.request.getproxies` for reading the proxy configuration (e.g. from the *HTTP_PROXY* etc. environment variables) and applies it to the *HTTP*, *HTTPS*, *WS* and *WSS* schemes. Hosts defined in ``no_proxy`` will bypass the proxy. Proxy credentials are taken from the ``~/.netrc`` file if present (see :class:`aiohttp.ClientSession` for more details). .. attention:: As of now (Python 3.10), support for TLS in TLS is disabled for the transports that :py:mod:`asyncio` uses. If a future release of Python (say v3.11) toggles one attribute, it'll *just work™*. aiohttp v3.8 and higher is ready for this to happen and has code in place that supports TLS-in-TLS, hence sending HTTPS requests over HTTPS proxy tunnels.
⚠️ For as long as your Python runtime doesn't declare the support for TLS-in-TLS, please don't file bugs with aiohttp but rather try to help the CPython upstream enable this feature. Meanwhile, if you *really* need this to work, there's a patch that may help you make it happen, include it into your app's code base: https://github.com/aio-libs/aiohttp/discussions/6044#discussioncomment-1432443. .. important:: When supplying a custom :py:class:`ssl.SSLContext` instance, bear in mind that it will be used not only to establish a TLS session with the HTTPS endpoint you're hitting but also to establish a TLS tunnel to the HTTPS proxy. To avoid surprises, make sure to set up the trust chain that would recognize TLS certificates used by both the endpoint and the proxy. .. _aiohttp-persistent-session: Persistent session ------------------ Even though creating a session on demand seems like a tempting idea, we advise against it. :class:`aiohttp.ClientSession` maintains a connection pool. Contained connections can be reused if necessary to gain some performance improvements. If you plan on reusing the session, a.k.a. creating **persistent session**, you can use either :ref:`aiohttp-web-signals` or :ref:`aiohttp-web-cleanup-ctx`. If possible we advise using :ref:`aiohttp-web-cleanup-ctx`, as it results in more compact code:: session = aiohttp.web.AppKey("session", aiohttp.ClientSession) @contextlib.asynccontextmanager async def persistent_session(app): app[persistent_session] = session = aiohttp.ClientSession() yield await session.close() async def my_request_handler(request): sess = request.app[session] async with sess.get("http://python.org") as resp: print(resp.status) app.cleanup_ctx.append(persistent_session) This approach can be successfully used to define numerous sessions given certain requirements. 
It benefits from having a single location where :class:`aiohttp.ClientSession` instances are created and where artifacts such as :class:`aiohttp.BaseConnector` can be safely shared between sessions if needed. In the end all you have to do is to close all sessions after the `yield` statement:: async def multiple_sessions(app): app[persistent_session_1] = session_1 = aiohttp.ClientSession() app[persistent_session_2] = session_2 = aiohttp.ClientSession() app[persistent_session_3] = session_3 = aiohttp.ClientSession() yield await asyncio.gather( session_1.close(), session_2.close(), session_3.close(), ) Graceful Shutdown ----------------- When :class:`ClientSession` closes at the end of an ``async with`` block (or through a direct :meth:`ClientSession.close` call), the underlying connection remains open due to asyncio internal details. In practice, the underlying connection will close after a short while. However, if the event loop is stopped before the underlying connection is closed, a ``ResourceWarning: unclosed transport`` warning is emitted (when warnings are enabled). To avoid this situation, a small delay must be added before closing the event loop to allow any open underlying connections to close. For a :class:`ClientSession` without SSL, a simple zero-sleep (``await asyncio.sleep(0)``) will suffice:: async def read_website(): async with aiohttp.ClientSession() as session: async with session.get('http://example.org/') as resp: await resp.read() # Zero-sleep to allow underlying connections to close await asyncio.sleep(0) For a :class:`ClientSession` with SSL, the application must wait a short duration before closing:: ... # Wait 250 ms for the underlying SSL connections to close await asyncio.sleep(0.250) Note that the appropriate amount of time to wait will vary from application to application. All of this will eventually become obsolete when the asyncio internals are changed so that aiohttp itself can wait on the underlying connection to close. 
Please follow issue `#1925 <https://github.com/aio-libs/aiohttp/issues/1925>`_ for the progress on this. HTTP Pipelining --------------- aiohttp does not support HTTP/HTTPS pipelining. Character Set Detection ----------------------- If you encounter a :exc:`UnicodeDecodeError` when using :meth:`ClientResponse.text` this may be because the response does not include the charset needed to decode the body. If you know the correct encoding for a request, you can simply specify the encoding as a parameter (e.g. ``resp.text("windows-1252")``). Alternatively, :class:`ClientSession` accepts a ``fallback_charset_resolver`` parameter which can be used to introduce charset guessing functionality. When a charset is not found in the Content-Type header, this function will be called to get the charset encoding. For example, this can be used with the ``chardetng_py`` library:: from chardetng_py import detect def charset_resolver(resp: ClientResponse, body: bytes) -> str: tld = resp.url.host.rsplit(".", maxsplit=1)[-1] return detect(body, allow_utf8=True, tld=tld.encode()) ClientSession(fallback_charset_resolver=charset_resolver) Or, if ``chardetng_py`` doesn't work for you, then ``charset-normalizer`` is another option:: from charset_normalizer import detect ClientSession(fallback_charset_resolver=lambda r, b: detect(b)["encoding"] or "utf-8") ================================================ FILE: docs/client_middleware_cookbook.rst ================================================ .. currentmodule:: aiohttp .. _aiohttp-client-middleware-cookbook: Client Middleware Cookbook ========================== This cookbook provides examples of how client middlewares can be used for common use cases. Simple Retry Middleware ----------------------- It's very easy to create middlewares that can retry a connection on a given condition: .. literalinclude:: code/client_middleware_cookbook.py :pyobject: retry_middleware .. warning:: It is recommended to ensure loops are bounded (e.g. using a ``for`` loop) to avoid creating an infinite loop.
Logging to an external service ------------------------------ If we needed to log our requests via an API call to an external server or similar, we could create a simple middleware like this: .. literalinclude:: code/client_middleware_cookbook.py :pyobject: api_logging_middleware .. warning:: Using the same session from within a middleware can cause infinite recursion if that request gets processed again by the middleware. To avoid such recursion a middleware should typically make requests with ``middlewares=()`` or else contain some condition to stop the request triggering the same logic when it is processed again by the middleware (e.g by whitelisting the API domain of the request). Token Refresh Middleware ------------------------ If you need to refresh access tokens to continue accessing an API, this is also a good candidate for a middleware. For example, you could check for a 401 response, then refresh the token and retry: .. literalinclude:: code/client_middleware_cookbook.py :pyobject: TokenRefresh401Middleware If you have an expiry time for the token, you could refresh at the expiry time, to avoid the failed request: .. literalinclude:: code/client_middleware_cookbook.py :pyobject: TokenRefreshExpiryMiddleware Or you could even refresh preemptively in a background task to avoid any API delays. This is probably more efficient to implement without a middleware: .. literalinclude:: code/client_middleware_cookbook.py :pyobject: token_refresh_preemptively_example :lines: 2- :dedent: Or combine the above approaches to create a more robust solution. .. note:: These can also be adjusted to handle proxy auth by modifying :attr:`ClientRequest.proxy_headers`. Server-side Request Forgery Protection -------------------------------------- To provide protection against server-side request forgery, we could blacklist any internal IPs or domains. We could create a middleware that rejects requests made to a blacklist: .. 
literalinclude:: code/client_middleware_cookbook.py :pyobject: ssrf_middleware .. warning:: The above example is simplified for demonstration purposes. A production-ready implementation should also check IPv6 addresses (``::1``), private IP ranges, link-local addresses, and other internal hostnames. Consider using a well-tested library for SSRF protection in production environments. If you know that your services correctly reject requests with an incorrect `Host` header, then that may provide sufficient protection. Otherwise, we still have a concern with an attacker's own domain resolving to a blacklisted IP. To provide complete protection, we can also create a custom resolver: .. literalinclude:: code/client_middleware_cookbook.py :pyobject: SSRFConnector Using both of these together in a session should provide full SSRF protection. Best Practices -------------- .. important:: **Request-level middlewares replace session middlewares**: When you pass ``middlewares`` to ``request()`` or its convenience methods (``get()``, ``post()``, etc.), it completely replaces the session-level middlewares, rather than extending them. This differs from other parameters like ``headers``, which are merged. .. code-block:: python session = ClientSession(middlewares=[middleware_session]) # Session middleware is used await session.get("http://example.com") # Session middleware is NOT used, only request middleware await session.get("http://example.com", middlewares=[middleware_request]) # To use both, explicitly pass both await session.get( "http://example.com", middlewares=[middleware_session, middleware_request] ) 1. **Keep middleware focused**: Each middleware should have a single responsibility. 2. **Order matters**: Middlewares execute in the order they're listed. Place logging first, authentication before retry, etc. 3. 
**Avoid infinite recursion**: When making HTTP requests inside middleware, either: - Use ``middlewares=()`` to disable middleware for internal requests - Check the request URL/host to skip middleware for specific endpoints - Use a separate session for internal requests 4. **Handle errors gracefully**: Don't let middleware errors break the request flow unless absolutely necessary. 5. **Use bounded loops**: Always use ``for`` loops with a maximum iteration count instead of unbounded ``while`` loops to prevent infinite retries. 6. **Consider performance**: Each middleware adds overhead. For simple cases like adding static headers, consider using session or request parameters instead. 7. **Test thoroughly**: Middleware can affect all requests in subtle ways. Test edge cases like network errors, timeouts, and concurrent requests. See Also -------- - :ref:`aiohttp-client-middleware` - Core middleware documentation - :ref:`aiohttp-client-advanced` - Advanced client usage - :class:`DigestAuthMiddleware` - Built-in digest authentication middleware ================================================ FILE: docs/client_quickstart.rst ================================================ .. currentmodule:: aiohttp .. _aiohttp-client-quickstart: =================== Client Quickstart =================== Eager to get started? This page gives a good introduction to getting started with the aiohttp client API. First, make sure that aiohttp is :ref:`installed <aiohttp-installation>` and *up-to-date*. Let's get started with some simple examples. Make a Request ============== Begin by importing the aiohttp module, and asyncio:: import aiohttp import asyncio Now, let's try to get a web-page.
For example let's query ``http://httpbin.org/get``:: async def main(): async with aiohttp.ClientSession() as session: async with session.get('http://httpbin.org/get') as resp: print(resp.status) print(await resp.text()) asyncio.run(main()) Now, we have a :class:`ClientSession` called ``session`` and a :class:`ClientResponse` object called ``resp``. We can get all the information we need from the response. The mandatory parameter of :meth:`ClientSession.get` coroutine is an HTTP *url* (:class:`str` or :class:`yarl.URL` instance). In order to make an HTTP POST request use :meth:`ClientSession.post` coroutine:: session.post('http://httpbin.org/post', data=b'data') Other HTTP methods are available as well:: session.put('http://httpbin.org/put', data=b'data') session.delete('http://httpbin.org/delete') session.head('http://httpbin.org/get') session.options('http://httpbin.org/get') session.patch('http://httpbin.org/patch', data=b'data') To make several requests to the same site simpler, the parameter ``base_url`` of :class:`ClientSession` constructor can be used. For example, to request different endpoints of ``http://httpbin.org`` the following code can be used:: async with aiohttp.ClientSession('http://httpbin.org') as session: async with session.get('/get'): pass async with session.post('/post', data=b'data'): pass async with session.put('/put', data=b'data'): pass .. note:: Don't create a session per request. Most likely you need a session per application which performs all requests together. More complex cases may require a session per site, e.g. one for GitHub and another one for Facebook APIs. Anyway making a session for every request is a **very bad** idea. A session contains a connection pool inside. Connection reuse and keep-alive (both are on by default) may speed up total performance. You may find more information about creating persistent sessions in :ref:`aiohttp-persistent-session`.
A session context manager usage is not mandatory but ``await session.close()`` method should be called in this case, e.g.:: session = aiohttp.ClientSession() async with session.get('...'): # ... await session.close() Passing Parameters In URLs ========================== You often want to send some sort of data in the URL's query string. If you were constructing the URL by hand, this data would be given as key/value pairs in the URL after a question mark, e.g. ``httpbin.org/get?key=val``. aiohttp allows you to provide these arguments as a :class:`dict`, using the ``params`` keyword argument. As an example, if you wanted to pass ``key1=value1`` and ``key2=value2`` to ``httpbin.org/get``, you would use the following code:: params = {'key1': 'value1', 'key2': 'value2'} async with session.get('http://httpbin.org/get', params=params) as resp: expect = 'http://httpbin.org/get?key1=value1&key2=value2' assert str(resp.url) == expect You can see that the URL has been correctly encoded by printing the URL. For sending data with multiple values for the same key :class:`~multidict.MultiDict` may be used; the library support nested lists (``{'key': ['value1', 'value2']}``) alternative as well. It is also possible to pass a list of 2 item tuples as parameters, in that case you can specify multiple values for each key:: params = [('key', 'value1'), ('key', 'value2')] async with session.get('http://httpbin.org/get', params=params) as r: expect = 'http://httpbin.org/get?key=value2&key=value1' assert str(r.url) == expect You can also pass :class:`str` content as param, but beware -- content is not encoded by library. Note that ``+`` is not encoded:: async with session.get('http://httpbin.org/get', params='key=value+1') as r: assert str(r.url) == 'http://httpbin.org/get?key=value+1' .. note:: *aiohttp* internally performs URL canonicalization before sending request. Canonicalization encodes *host* part by :term:`IDNA` codec and applies :term:`requoting` to *path* and *query* parts. 
For example ``URL('http://example.com/путь/%30?a=%31')`` is converted to ``URL('http://example.com/%D0%BF%D1%83%D1%82%D1%8C/0?a=1')``. Sometimes canonicalization is not desirable if server accepts exact representation and does not requote URL itself. To disable canonicalization use ``encoded=True`` parameter for URL construction:: await session.get( URL('http://example.com/%30', encoded=True)) .. warning:: Passing *params* overrides ``encoded=True``, never use both options. Response Content and Status Code ================================ We can read the content of the server's response and its status code. Consider the GitHub time-line again:: async with session.get('https://api.github.com/events') as resp: print(resp.status) print(await resp.text()) prints out something like:: 200 '[{"created_at":"2015-06-12T14:06:22Z","public":true,"actor":{... ``aiohttp`` automatically decodes the content from the server. You can specify custom encoding for the :meth:`~ClientResponse.text` method:: await resp.text(encoding='windows-1251') Binary Response Content ======================= You can also access the response body as bytes, for non-text requests:: print(await resp.read()) :: b'[{"created_at":"2015-06-12T14:06:22Z","public":true,"actor":{... The ``gzip`` and ``deflate`` transfer-encodings are automatically decoded for you. To enable ``brotli`` transfer-encoding support, just install `Brotli <https://pypi.org/project/Brotli/>`_ or `brotlicffi <https://pypi.org/project/brotlicffi/>`_. To enable ``zstd`` transfer-encoding support, install `backports.zstd <https://pypi.org/project/backports.zstd/>`_. If you are using Python >= 3.14, no dependency should be required. JSON Request ============ Any of session's request methods like :func:`request`, :meth:`ClientSession.get`, :meth:`ClientSession.post` etc. accept `json` parameter:: async with aiohttp.ClientSession() as session: await session.post(url, json={'test': 'object'}) By default session uses python's standard :mod:`json` module for serialization. But it is possible to use a different ``serializer``.
:class:`ClientSession` accepts ``json_serialize`` and ``json_serialize_bytes`` parameters:: import orjson async with aiohttp.ClientSession( json_serialize_bytes=orjson.dumps) as session: await session.post(url, json={'test': 'object'}) .. note:: ``orjson`` library is faster than standard :mod:`json` and is actively maintained. Since ``orjson.dumps`` returns :class:`bytes`, pass it via the ``json_serialize_bytes`` parameter to avoid unnecessary encoding/decoding overhead. JSON Response Content ===================== There's also a built-in JSON decoder, in case you're dealing with JSON data:: async with session.get('https://api.github.com/events') as resp: print(await resp.json()) If JSON decoding fails, :meth:`~ClientResponse.json` will raise an exception. It is possible to specify custom encoding and decoder functions for the :meth:`~ClientResponse.json` call. .. note:: The methods above read the whole response body into memory. If you are planning on reading lots of data, consider using the streaming response method documented below. Streaming Response Content ========================== While methods :meth:`~ClientResponse.read`, :meth:`~ClientResponse.json` and :meth:`~ClientResponse.text` are very convenient you should use them carefully. All these methods load the whole response in memory. For example if you want to download several gigabyte sized files, these methods will load all the data in memory. Instead you can use the :attr:`~ClientResponse.content` attribute. It is an instance of the :class:`aiohttp.StreamReader` class.
The ``gzip`` and ``deflate`` transfer-encodings are automatically decoded for you:: async with session.get('https://api.github.com/events') as resp: await resp.content.read(10) In general, however, you should use a pattern like this to save what is being streamed to a file:: with open(filename, 'wb') as fd: async for chunk in resp.content.iter_chunked(chunk_size): fd.write(chunk) It is not possible to use :meth:`~ClientResponse.read`, :meth:`~ClientResponse.json` and :meth:`~ClientResponse.text` after explicit reading from :attr:`~ClientResponse.content`. More complicated POST requests ============================== Typically, you want to send some form-encoded data -- much like an HTML form. To do this, simply pass a dictionary to the *data* argument. Your dictionary of data will automatically be form-encoded when the request is made:: payload = {'key1': 'value1', 'key2': 'value2'} async with session.post('http://httpbin.org/post', data=payload) as resp: print(await resp.text()) :: { ... "form": { "key2": "value2", "key1": "value1" }, ... } If you want to send data that is not form-encoded you can do it by passing a :class:`bytes` instead of a :class:`dict`. This data will be posted directly and content-type set to 'application/octet-stream' by default:: async with session.post(url, data=b'\x00Binary-data\x00') as resp: ... If you want to send JSON data:: async with session.post(url, json={'example': 'test'}) as resp: ... To send text with appropriate content-type just use ``data`` argument:: async with session.post(url, data='Тест') as resp: ... 
POST a Multipart-Encoded File ============================= To upload Multipart-encoded files:: url = 'http://httpbin.org/post' files = {'file': open('report.xls', 'rb')} await session.post(url, data=files) You can set the ``filename`` and ``content_type`` explicitly:: url = 'http://httpbin.org/post' data = aiohttp.FormData() data.add_field('file', open('report.xls', 'rb'), filename='report.xls', content_type='application/vnd.ms-excel') await session.post(url, data=data) If you pass a file object as data parameter, aiohttp will stream it to the server automatically. Check :class:`~aiohttp.StreamReader` for supported format information. .. seealso:: :ref:`aiohttp-multipart` Streaming uploads ================= :mod:`aiohttp` supports multiple types of streaming uploads, which allows you to send large files without reading them into memory. As a simple case, simply provide a file-like object for your body:: with open("massive-body", "rb") as f: await session.post("https://httpbin.org/post", data=f) Or you can provide an *asynchronous generator*, for example to generate data on the fly:: async def data_generator(): for i in range(10): yield f"line {i}\n".encode() async with session.post("https://httpbin.org/post", data=data_generator()) as resp: print(await resp.text()) .. warning:: Async generators and other non-rewindable data sources (such as :class:`~aiohttp.StreamReader`) cannot be replayed if a redirect occurs (for example, HTTP 307 or 308). If the request body has already been streamed, :mod:`aiohttp` raises :class:`~aiohttp.ClientPayloadError`. If your endpoint may redirect, either: * Pass a seekable file-like object or :class:`bytes`. * Disable redirects with ``allow_redirects=False`` and handle them manually. 
Because the :attr:`~aiohttp.ClientResponse.content` attribute is a :class:`~aiohttp.StreamReader` (provides async iterator protocol), you can chain get and post requests together:: resp = await session.get('http://python.org') await session.post('http://httpbin.org/post', data=resp.content) .. _aiohttp-client-websockets: WebSockets ========== :mod:`aiohttp` works with client websockets out-of-the-box. You have to use the :meth:`aiohttp.ClientSession.ws_connect` coroutine for client websocket connection. It accepts a *url* as a first parameter and returns :class:`ClientWebSocketResponse`; with that object you can communicate with the websocket server using the response's methods:: async with session.ws_connect('http://example.org/ws') as ws: async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: if msg.data == 'close cmd': await ws.close() break else: await ws.send_str(msg.data + '/answer') elif msg.type == aiohttp.WSMsgType.ERROR: break You **must** use only one websocket task for both reading (e.g. ``await ws.receive()`` or ``async for msg in ws:``) and writing, but may have multiple writer tasks which can only send data asynchronously (by ``await ws.send_str('data')`` for example). .. _aiohttp-client-timeouts: Timeouts ======== Timeout settings are stored in :class:`ClientTimeout` data structure. By default *aiohttp* uses a *total* timeout of 300 seconds (5 minutes), which means that the whole operation should finish in 5 minutes. In order to allow time for DNS fallback, the default ``sock_connect`` timeout is 30 seconds. The value can be overridden by the *timeout* parameter for the session (specified in seconds):: timeout = aiohttp.ClientTimeout(total=60) async with aiohttp.ClientSession(timeout=timeout) as session: ... Timeout can be overridden for a request like :meth:`ClientSession.get`:: async with session.get(url, timeout=timeout) as resp: ...
Supported :class:`ClientTimeout` fields are: ``total`` The maximal number of seconds for the whole operation including connection establishment, request sending and response reading. ``connect`` The maximal number of seconds for connection establishment of a new connection or for waiting for a free connection from a pool if pool connection limits are exceeded. ``sock_connect`` The maximal number of seconds for connecting to a peer for a new connection, not given from a pool. ``sock_read`` The maximal number of seconds allowed for period between reading a new data portion from a peer. ``ceil_threshold`` The threshold value to trigger ceiling of absolute timeout values. All fields are floats, ``None`` or ``0`` disables a particular timeout check, see the :class:`ClientTimeout` reference for defaults and additional details. Thus the default timeout is:: aiohttp.ClientTimeout(total=5*60, connect=None, sock_connect=30, sock_read=None, ceil_threshold=5) .. note:: *aiohttp* **ceils** timeout if the value is equal to or greater than 5 seconds. The timeout expires at the next integer second greater than ``current_time + timeout``. The ceiling is done for the sake of optimization, when many concurrent tasks are scheduled to wake-up at almost the same but different absolute times. It leads to very many event loop wakeups, which kills performance. The optimization shifts absolute wakeup times by scheduling them to exactly the same time as other neighbors, the loop wakes up once-per-second for timeout expiration. Smaller timeouts are not rounded to help testing; in real life, network timeouts are usually greater than tens of seconds. However, the default threshold value of 5 seconds can be configured using the ``ceil_threshold`` parameter. ================================================ FILE: docs/client_reference.rst ================================================ .. _aiohttp-client-reference: Client Reference ================ ..
currentmodule:: aiohttp Client Session -------------- Client session is the recommended interface for making HTTP requests. Session encapsulates a *connection pool* (*connector* instance) and supports keepalives by default. Unless you are connecting to a large, unknown number of different servers over the lifetime of your application, it is suggested you use a single session for the lifetime of your application to benefit from connection pooling. Usage example:: import aiohttp import asyncio async def fetch(client): async with client.get('http://python.org') as resp: assert resp.status == 200 return await resp.text() async def main(): async with aiohttp.ClientSession() as client: html = await fetch(client) print(html) asyncio.run(main()) The client session supports the context manager protocol for self closing. .. class:: ClientSession(base_url=None, *, \ connector=None, cookies=None, \ headers=None, skip_auto_headers=None, \ auth=None, json_serialize=json.dumps, \ request_class=ClientRequest, \ response_class=ClientResponse, \ ws_response_class=ClientWebSocketResponse, \ version=aiohttp.HttpVersion11, \ cookie_jar=None, \ connector_owner=True, \ raise_for_status=False, \ timeout=sentinel, \ auto_decompress=True, \ trust_env=False, \ requote_redirect_url=True, \ trace_configs=None, \ middlewares=(), \ read_bufsize=2**16, \ max_line_size=8190, \ max_field_size=8190, \ max_headers=128, \ fallback_charset_resolver=lambda r, b: "utf-8", \ ssl_shutdown_timeout=0) :canonical: aiohttp.client.ClientSession The class for creating client sessions and making requests. :param base_url: Base part of the URL (optional) If set, allows to join a base part to relative URLs in request calls. If the URL has a path it must have a trailing ``/`` (as in https://docs.aiohttp.org/en/stable/). Note that URL joining follows :rfc:`3986`. 
This means, in the most common case the request URLs should have no leading slash, e.g.:: session = ClientSession(base_url="http://example.com/foo/") await session.request("GET", "bar") # request for http://example.com/foo/bar await session.request("GET", "/bar") # request for http://example.com/bar .. versionadded:: 3.8 .. versionchanged:: 3.12 Added support for overriding the base URL with an absolute one in client sessions. :param aiohttp.BaseConnector connector: BaseConnector sub-class instance to support connection pooling. :param dict cookies: Cookies to send with the request (optional) :param headers: HTTP Headers to send with every request (optional). May be either *iterable of key-value pairs* or :class:`~collections.abc.Mapping` (e.g. :class:`dict`, :class:`~multidict.CIMultiDict`). :param skip_auto_headers: set of headers for which autogeneration should be skipped. *aiohttp* autogenerates headers like ``User-Agent`` or ``Content-Type`` if these headers are not explicitly passed. Using ``skip_auto_headers`` parameter allows to skip that generation. Note that ``Content-Length`` autogeneration can't be skipped. Iterable of :class:`str` or :class:`~multidict.istr` (optional) :param aiohttp.BasicAuth auth: an object that represents HTTP Basic Authorization (optional). It will be included with any request. However, if the ``base_url`` parameter is set, the request URL's origin must match the base URL's origin; otherwise, the default auth will not be included. :param collections.abc.Callable json_serialize: Json *serializer* callable. By default :func:`json.dumps` function. :param aiohttp.ClientRequest request_class: Custom class to use for client requests. :param ClientResponse response_class: Custom class to use for client responses. :param ClientWebSocketResponse ws_response_class: Custom class to use for websocket responses. :param version: supported HTTP version, ``HTTP 1.1`` by default.
:param cookie_jar: Cookie Jar, :class:`~aiohttp.abc.AbstractCookieJar` instance. By default every session instance has own private cookie jar for automatic cookies processing but user may redefine this behavior by providing own jar implementation. One example is not processing cookies at all when working in proxy mode. If no cookie processing is needed, a :class:`aiohttp.DummyCookieJar` instance can be provided. :param bool connector_owner: Close connector instance on session closing. Setting the parameter to ``False`` allows to share connection pool between sessions without sharing session state: cookies etc. :param bool raise_for_status: Automatically call :meth:`ClientResponse.raise_for_status` for each response, ``False`` by default. This parameter can be overridden when making a request, e.g.:: client_session = aiohttp.ClientSession(raise_for_status=True) resp = await client_session.get(url, raise_for_status=False) async with resp: assert resp.status == 200 Set the parameter to ``True`` if you need ``raise_for_status`` for most of cases but override ``raise_for_status`` for those requests where you need to handle responses with status 400 or higher. :param timeout: a :class:`ClientTimeout` settings structure, 300 seconds (5min) total timeout, 30 seconds socket connect timeout by default. .. versionadded:: 3.3 .. versionchanged:: 3.10.9 The default value for the ``sock_connect`` timeout has been changed to 30 seconds. :param bool auto_decompress: Automatically decompress response body (``True`` by default). .. versionadded:: 2.3 :param bool trust_env: Trust environment settings for proxy configuration if the parameter is ``True`` (``False`` by default). See :ref:`aiohttp-client-proxy-support` for more information. Get proxy credentials from ``~/.netrc`` file if present. Get HTTP Basic Auth credentials from :file:`~/.netrc` file if present. If :envvar:`NETRC` environment variable is set, read from file specified there rather than from :file:`~/.netrc`. .. 
seealso:: ``.netrc`` documentation: https://www.gnu.org/software/inetutils/manual/html_node/The-_002enetrc-file.html .. versionadded:: 2.3 .. versionchanged:: 3.0 Added support for ``~/.netrc`` file. .. versionchanged:: 3.9 Added support for reading HTTP Basic Auth credentials from :file:`~/.netrc` file. :param bool requote_redirect_url: Apply *URL requoting* for redirection URLs if automatic redirection is enabled (``True`` by default). .. versionadded:: 3.5 :param trace_configs: A list of :class:`TraceConfig` instances used for client tracing. ``None`` (default) is used for request tracing disabling. See :ref:`aiohttp-client-tracing-reference` for more information. :param middlewares: A sequence of middleware instances to apply to all session requests. Each middleware must match the :type:`ClientMiddlewareType` signature. ``()`` (empty tuple, default) is used when no middleware is needed. See :ref:`aiohttp-client-middleware` for more information. .. versionadded:: 3.12 :param int read_bufsize: Size of the read buffer (:attr:`ClientResponse.content`). 64 KiB by default. .. versionadded:: 3.7 :param int max_line_size: Maximum allowed size of lines in responses. :param int max_field_size: Maximum allowed size of header name and value combined in responses. :param int max_headers: Maximum number of headers and trailers combined in responses. :param Callable[[ClientResponse,bytes],str] fallback_charset_resolver: A :term:`callable` that accepts a :class:`ClientResponse` and the :class:`bytes` contents, and returns a :class:`str` which will be used as the encoding parameter to :meth:`bytes.decode()`. This function will be called when the charset is not known (e.g. not specified in the Content-Type header). The default function simply defaults to ``utf-8``. .. versionadded:: 3.8.6 :param float ssl_shutdown_timeout: **(DEPRECATED)** This parameter is deprecated and will be removed in aiohttp 4.0. 
Grace period for SSL shutdown handshake on TLS connections when the connector is closed (``0`` seconds by default). By default (``0``), SSL connections are aborted immediately when the connector is closed, without performing the shutdown handshake. During normal operation, SSL connections use Python's default SSL shutdown behavior. Setting this to a positive value (e.g., ``0.1``) will perform a graceful shutdown when closing the connector, notifying the remote peer which can help prevent "connection reset" errors at the cost of additional cleanup time. This timeout is passed to the underlying :class:`TCPConnector` when one is created automatically. Note: On Python versions prior to 3.11, only a value of ``0`` is supported; other values will trigger a warning. .. versionadded:: 3.12.5 .. versionchanged:: 3.12.11 Changed default from ``0.1`` to ``0`` to abort SSL connections immediately when the connector is closed. Added support for ``ssl_shutdown_timeout=0`` on all Python versions. A :exc:`RuntimeWarning` is issued when non-zero values are passed on Python < 3.11. .. deprecated:: 3.12.11 This parameter is deprecated and will be removed in aiohttp 4.0. .. attribute:: closed ``True`` if the session has been closed, ``False`` otherwise. A read-only property. .. attribute:: connector :class:`aiohttp.BaseConnector` derived instance used for the session. A read-only property. .. attribute:: cookie_jar The session cookies, :class:`~aiohttp.abc.AbstractCookieJar` instance. Gives access to cookie jar's content and modifiers. A read-only property. .. attribute:: requote_redirect_url aiohttp requotes redirect URLs by default, but some servers require exact url from location header. To disable *re-quote* system set :attr:`requote_redirect_url` attribute to ``False``. .. versionadded:: 2.1 .. note:: This parameter affects all subsequent requests. .. deprecated:: 3.5 The attribute modification is deprecated. .. attribute:: loop A loop instance used for session creation.
A read-only property. .. deprecated:: 3.5 .. attribute:: timeout Default client timeouts, :class:`ClientTimeout` instance. The value can be tuned by passing *timeout* parameter to :class:`ClientSession` constructor. .. versionadded:: 3.7 .. attribute:: headers HTTP Headers that are sent with every request. May be either *iterable of key-value pairs* or :class:`~collections.abc.Mapping` (e.g. :class:`dict`, :class:`~multidict.CIMultiDict`). .. versionadded:: 3.7 .. attribute:: skip_auto_headers Set of headers for which autogeneration is skipped. :class:`frozenset` of :class:`str` or :class:`~multidict.istr` (optional) .. versionadded:: 3.7 .. attribute:: auth An object that represents HTTP Basic Authorization. :class:`~aiohttp.BasicAuth` (optional) .. versionadded:: 3.7 .. attribute:: json_serialize Json serializer callable. By default :func:`json.dumps` function. .. versionadded:: 3.7 .. attribute:: connector_owner Should connector be closed on session closing :class:`bool` (optional) .. versionadded:: 3.7 .. attribute:: raise_for_status Should :meth:`ClientResponse.raise_for_status` be called for each response Either :class:`bool` or :class:`collections.abc.Callable` .. versionadded:: 3.7 .. attribute:: auto_decompress Should the body response be automatically decompressed :class:`bool` default is ``True`` .. versionadded:: 3.7 .. attribute:: trust_env Trust environment settings for proxy configuration or ~/.netrc file if present. See :ref:`aiohttp-client-proxy-support` for more information. :class:`bool` default is ``False`` .. versionadded:: 3.7 .. attribute:: trace_configs A list of :class:`TraceConfig` instances used for client tracing. ``None`` (default) is used for request tracing disabling. See :ref:`aiohttp-client-tracing-reference` for more information. .. versionadded:: 3.7 ..
method:: request(method, url, *, params=None, data=None, json=None,\ cookies=None, headers=None, skip_auto_headers=None, \ auth=None, allow_redirects=True,\ max_redirects=10,\ compress=None, chunked=None, expect100=False, raise_for_status=None,\ read_until_eof=True, \ proxy=None, proxy_auth=None,\ timeout=sentinel, ssl=True, \ server_hostname=None, \ proxy_headers=None, \ trace_request_ctx=None, \ middlewares=None, \ read_bufsize=None, \ auto_decompress=None, \ max_line_size=None, \ max_field_size=None, \ max_headers=None) :async: :noindexentry: Performs an asynchronous HTTP request. Returns a response object that should be used as an async context manager. :param str method: HTTP method :param url: Request URL, :class:`~yarl.URL` or :class:`str` that will be encoded with :class:`~yarl.URL` (see :class:`~yarl.URL` to skip encoding). :param params: Mapping, iterable of tuple of *key*/*value* pairs or string to be sent as parameters in the query string of the new request. Ignored for subsequent redirected requests (optional) Allowed values are: - :class:`collections.abc.Mapping` e.g. :class:`dict`, :class:`multidict.MultiDict` or :class:`multidict.MultiDictProxy` - :class:`collections.abc.Iterable` e.g. :class:`tuple` or :class:`list` - :class:`str` with preferably url-encoded content (**Warning:** content will not be encoded by *aiohttp*) :param data: The data to send in the body of the request. This can be a :class:`FormData` object or anything that can be passed into :class:`FormData`, e.g. a dictionary, bytes, or file-like object. (optional) :param json: Any json compatible python object (optional). *json* and *data* parameters could not be used at the same time. :param dict cookies: HTTP Cookies to send with the request (optional) Global session cookies and the explicitly set cookies will be merged when sending the request. .. 
versionadded:: 3.5 :param dict headers: HTTP Headers to send with the request (optional) :param skip_auto_headers: set of headers for which autogeneration should be skipped. *aiohttp* autogenerates headers like ``User-Agent`` or ``Content-Type`` if these headers are not explicitly passed. Using ``skip_auto_headers`` parameter allows to skip that generation. Iterable of :class:`str` or :class:`~multidict.istr` (optional) :param aiohttp.BasicAuth auth: an object that represents HTTP Basic Authorization (optional) :param bool allow_redirects: Whether to process redirects or not. When ``True``, redirects are followed (up to ``max_redirects`` times) and logged into :attr:`ClientResponse.history` and ``trace_configs``. When ``False``, the original response is returned. ``True`` by default (optional). :param int max_redirects: Maximum number of redirects to follow. :exc:`TooManyRedirects` is raised if the number is exceeded. Ignored when ``allow_redirects=False``. ``10`` by default. :param bool compress: Set to ``True`` if request has to be compressed with deflate encoding. *compress* cannot be combined with *Content-Encoding* and *Content-Length* headers. ``None`` by default (optional). :param int chunked: Enable chunked transfer encoding. It is up to the developer to decide how to chunk data streams. If chunking is enabled, aiohttp encodes the provided chunks in the "Transfer-encoding: chunked" format. If *chunked* is set, then the *Transfer-encoding* and *content-length* headers are disallowed. ``None`` by default (optional). :param bool expect100: Expect 100-continue response from server. ``False`` by default (optional). :param bool raise_for_status: Automatically call :meth:`ClientResponse.raise_for_status` for response if set to ``True``. If set to ``None`` value from ``ClientSession`` will be used. ``None`` by default (optional). .. versionadded:: 3.4 :param bool read_until_eof: Read response until EOF if response does not have Content-Length header.
``True`` by default (optional). :param proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP Basic Authorization (optional) :param int timeout: override the session's timeout. .. versionchanged:: 3.3 The parameter is :class:`ClientTimeout` instance, :class:`float` is still supported for sake of backward compatibility. If :class:`float` is passed it is a *total* timeout (in seconds). :param ssl: SSL validation mode. ``True`` for default SSL check (:func:`ssl.create_default_context` is used), ``False`` for skip SSL certificate validation, :class:`aiohttp.Fingerprint` for fingerprint validation, :class:`ssl.SSLContext` for custom SSL certificate validation. Supersedes *verify_ssl*, *ssl_context* and *fingerprint* parameters. .. versionadded:: 3.0 :param str server_hostname: Sets or overrides the host name that the target server's certificate will be matched against. See :py:meth:`asyncio.loop.create_connection` for more information. .. versionadded:: 3.9 :param collections.abc.Mapping proxy_headers: HTTP headers to send to the proxy if the parameter proxy has been provided. .. versionadded:: 2.3 :param trace_request_ctx: Object used to give as a kw param for each new :class:`TraceConfig` object instantiated, used to give information to the tracers that is only available at request time. .. versionadded:: 3.0 :param middlewares: A sequence of middleware instances to apply to this request only. Each middleware must match the :type:`ClientMiddlewareType` signature. ``None`` by default which uses session middlewares. See :ref:`aiohttp-client-middleware` for more information. .. versionadded:: 3.12 :param int read_bufsize: Size of the read buffer (:attr:`ClientResponse.content`). ``None`` by default, it means that the session global value is used. .. versionadded:: 3.7 :param bool auto_decompress: Automatically decompress response body. Overrides :attr:`ClientSession.auto_decompress`. 
May be used to enable/disable auto decompression on a per-request basis. :param int max_line_size: Maximum allowed size of lines in responses. :param int max_field_size: Maximum allowed size of header name and value combined in responses. :param int max_headers: Maximum number of headers and trailers combined in responses. :return ClientResponse: a :class:`client response <ClientResponse>` object. .. method:: get(url, *, allow_redirects=True, **kwargs) :async: Perform a ``GET`` request. Returns an async context manager. In order to modify inner :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` :param bool allow_redirects: Whether to process redirects or not. When ``True``, redirects are followed and logged into :attr:`ClientResponse.history`. When ``False``, the original response is returned. ``True`` by default (optional). :return ClientResponse: a :class:`client response <ClientResponse>` object. .. method:: post(url, *, data=None, **kwargs) :async: Perform a ``POST`` request. Returns an async context manager. In order to modify inner :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` :param data: Data to send in the body of the request; see :meth:`request` for details (optional) :return ClientResponse: a :class:`client response <ClientResponse>` object. .. method:: put(url, *, data=None, **kwargs) :async: Perform a ``PUT`` request. Returns an async context manager. In order to modify inner :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` :param data: Data to send in the body of the request; see :meth:`request` for details (optional) :return ClientResponse: a :class:`client response <ClientResponse>` object. .. method:: delete(url, **kwargs) :async: Perform a ``DELETE`` request. Returns an async context manager. In order to modify inner :meth:`request` parameters, provide `kwargs`.
:param url: Request URL, :class:`str` or :class:`~yarl.URL` :return ClientResponse: a :class:`client response <ClientResponse>` object. .. method:: head(url, *, allow_redirects=False, **kwargs) :async: Perform a ``HEAD`` request. Returns an async context manager. In order to modify inner :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` :param bool allow_redirects: Whether to process redirects or not. When ``True``, redirects are followed and logged into :attr:`ClientResponse.history`. When ``False``, the original response is returned. ``False`` by default (optional). :return ClientResponse: a :class:`client response <ClientResponse>` object. .. method:: options(url, *, allow_redirects=True, **kwargs) :async: Perform an ``OPTIONS`` request. Returns an async context manager. In order to modify inner :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` :param bool allow_redirects: Whether to process redirects or not. When ``True``, redirects are followed and logged into :attr:`ClientResponse.history`. When ``False``, the original response is returned. ``True`` by default (optional). :return ClientResponse: a :class:`client response <ClientResponse>` object. .. method:: patch(url, *, data=None, **kwargs) :async: Perform a ``PATCH`` request. Returns an async context manager. In order to modify inner :meth:`request` parameters, provide `kwargs`. :param url: Request URL, :class:`str` or :class:`~yarl.URL` :param data: Data to send in the body of the request; see :meth:`request` for details (optional) :return ClientResponse: a :class:`client response <ClientResponse>` object. ..
method:: ws_connect(url, *, method='GET', \ protocols=(), \ timeout=sentinel,\ auth=None,\ autoclose=True,\ autoping=True,\ heartbeat=None,\ origin=None, \ params=None, \ headers=None, \ proxy=None, proxy_auth=None, ssl=True, \ verify_ssl=None, fingerprint=None, \ ssl_context=None, proxy_headers=None, \ compress=0, max_msg_size=4194304, \ decode_text=True) :async: Create a websocket connection. Returns a :class:`ClientWebSocketResponse` async context manager object. :param url: Websocket server url, :class:`~yarl.URL` or :class:`str` that will be encoded with :class:`~yarl.URL` (see :class:`~yarl.URL` to skip encoding). :param tuple protocols: Websocket protocols :param timeout: a :class:`ClientWSTimeout` timeout for websocket. By default, the value `ClientWSTimeout(ws_receive=None, ws_close=10.0)` is used (``10.0`` seconds for the websocket to close). ``None`` means no timeout will be used. :param aiohttp.BasicAuth auth: an object that represents HTTP Basic Authorization (optional) :param bool autoclose: Automatically close websocket connection on close message from server. If *autoclose* is False then close procedure has to be handled manually. ``True`` by default :param bool autoping: automatically send *pong* on *ping* message from server. ``True`` by default :param float heartbeat: Send *ping* message every *heartbeat* seconds and wait *pong* response, if *pong* response is not received then close connection. The timer is reset on any inbound data reception (coalesced per event loop iteration). (optional) :param str origin: Origin header to send to server(optional) :param params: Mapping, iterable of tuple of *key*/*value* pairs or string to be sent as parameters in the query string of the new request. Ignored for subsequent redirected requests (optional) Allowed values are: - :class:`collections.abc.Mapping` e.g. :class:`dict`, :class:`multidict.MultiDict` or :class:`multidict.MultiDictProxy` - :class:`collections.abc.Iterable` e.g. 
:class:`tuple` or :class:`list` - :class:`str` with preferably url-encoded content (**Warning:** content will not be encoded by *aiohttp*) :param dict headers: HTTP Headers to send with the request (optional) :param str proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP Basic Authorization (optional) :param ssl: SSL validation mode. ``True`` for default SSL check (:func:`ssl.create_default_context` is used), ``False`` for skip SSL certificate validation, :class:`aiohttp.Fingerprint` for fingerprint validation, :class:`ssl.SSLContext` for custom SSL certificate validation. Supersedes *verify_ssl*, *ssl_context* and *fingerprint* parameters. .. versionadded:: 3.0 :param bool verify_ssl: Perform SSL certificate validation for *HTTPS* requests (enabled by default). May be disabled to skip validation for sites with invalid certificates. .. versionadded:: 2.3 .. deprecated:: 3.0 Use ``ssl=False`` :param bytes fingerprint: Pass the SHA256 digest of the expected certificate in DER format to verify that the certificate the server presents matches. Useful for `certificate pinning `_. Note: use of MD5 or SHA1 digests is insecure and deprecated. .. versionadded:: 2.3 .. deprecated:: 3.0 Use ``ssl=aiohttp.Fingerprint(digest)`` :param ssl.SSLContext ssl_context: ssl context used for processing *HTTPS* requests (optional). *ssl_context* may be used for configuring certification authority channel, supported SSL options etc. .. versionadded:: 2.3 .. deprecated:: 3.0 Use ``ssl=ssl_context`` :param dict proxy_headers: HTTP headers to send to the proxy if the parameter proxy has been provided. .. versionadded:: 2.3 :param int compress: Enable Per-Message Compress Extension support. 0 for disable, 9 to 15 for window bit support. Default value is 0. .. versionadded:: 2.3 :param int max_msg_size: maximum size of read websocket message, 4 MB by default. To disable the size limit use ``0``. .. 
versionadded:: 3.3 :param str method: HTTP method to establish WebSocket connection, ``'GET'`` by default. .. versionadded:: 3.5 :param bool decode_text: If ``True`` (default), TEXT messages are decoded to strings. If ``False``, TEXT messages are returned as raw bytes, which can improve performance when using JSON parsers like ``orjson`` that accept bytes directly. .. versionadded:: 3.14 .. method:: close() :async: Close underlying connector. Release all acquired resources. .. method:: detach() Detach connector from session without closing the former. Session is switched to closed state anyway. Basic API --------- While we encourage :class:`ClientSession` usage we also provide simple coroutines for making HTTP requests. Basic API is good for performing simple HTTP requests without keepaliving, cookies and complex connection stuff like properly configured SSL certification chaining. .. function:: request(method, url, *, params=None, data=None, \ json=None,\ cookies=None, headers=None, skip_auto_headers=None, auth=None, \ allow_redirects=True, max_redirects=10, \ compress=False, chunked=None, expect100=False, raise_for_status=None, \ read_until_eof=True, \ proxy=None, proxy_auth=None, \ timeout=sentinel, ssl=True, \ server_hostname=None, \ proxy_headers=None, \ trace_request_ctx=None, \ read_bufsize=None, \ auto_decompress=None, \ max_line_size=None, \ max_field_size=None, \ max_headers=None, \ version=aiohttp.HttpVersion11, \ connector=None) :canonical: aiohttp.client.request :async: Asynchronous context manager for performing an asynchronous HTTP request. Returns a :class:`ClientResponse` response object. Use as an async context manager. :param str method: HTTP method :param url: Request URL, :class:`~yarl.URL` or :class:`str` that will be encoded with :class:`~yarl.URL` (see :class:`~yarl.URL` to skip encoding). :param params: Mapping, iterable of tuple of *key*/*value* pairs or string to be sent as parameters in the query string of the new request. 
Ignored for subsequent redirected requests (optional) Allowed values are: - :class:`collections.abc.Mapping` e.g. :class:`dict`, :class:`multidict.MultiDict` or :class:`multidict.MultiDictProxy` - :class:`collections.abc.Iterable` e.g. :class:`tuple` or :class:`list` - :class:`str` with preferably url-encoded content (**Warning:** content will not be encoded by *aiohttp*) :param data: The data to send in the body of the request. This can be a :class:`FormData` object or anything that can be passed into :class:`FormData`, e.g. a dictionary, bytes, or file-like object. (optional) :param json: Any json compatible python object (optional). *json* and *data* parameters could not be used at the same time. :param dict cookies: HTTP Cookies to send with the request (optional) :param dict headers: HTTP Headers to send with the request (optional) :param skip_auto_headers: set of headers for which autogeneration should be skipped. *aiohttp* autogenerates headers like ``User-Agent`` or ``Content-Type`` if these headers are not explicitly passed. Using ``skip_auto_headers`` parameter allows to skip that generation. Iterable of :class:`str` or :class:`~multidict.istr` (optional) :param aiohttp.BasicAuth auth: an object that represents HTTP Basic Authorization (optional) :param bool allow_redirects: Whether to process redirects or not. When ``True``, redirects are followed (up to ``max_redirects`` times) and logged into :attr:`ClientResponse.history` and ``trace_configs``. When ``False``, the original response is returned. ``True`` by default (optional). :param int max_redirects: Maximum number of redirects to follow. :exc:`TooManyRedirects` is raised if the number is exceeded. Ignored when ``allow_redirects=False``. ``10`` by default. :param bool compress: Set to ``True`` if request has to be compressed with deflate encoding. *compress* cannot be combined with *Content-Encoding* and *Content-Length* headers. ``None`` by default (optional).
:param int chunked: Enables chunked transfer encoding. It is up to the developer to decide how to chunk data streams. If chunking is enabled, aiohttp encodes the provided chunks in the "Transfer-encoding: chunked" format. If *chunked* is set, then the *Transfer-encoding* and *content-length* headers are disallowed. ``None`` by default (optional). :param bool expect100: Expect 100-continue response from server. ``False`` by default (optional). :param bool raise_for_status: Automatically call :meth:`ClientResponse.raise_for_status` for response if set to ``True``. If set to ``None`` value from ``ClientSession`` will be used. ``None`` by default (optional). .. versionadded:: 3.4 :param bool read_until_eof: Read response until EOF if response does not have Content-Length header. ``True`` by default (optional). :param proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP Basic Authorization (optional) :param timeout: a :class:`ClientTimeout` settings structure, 300 seconds (5min) total timeout, 30 seconds socket connect timeout by default. :param ssl: SSL validation mode. ``True`` for default SSL check (:func:`ssl.create_default_context` is used), ``False`` for skip SSL certificate validation, :class:`aiohttp.Fingerprint` for fingerprint validation, :class:`ssl.SSLContext` for custom SSL certificate validation. Supersedes *verify_ssl*, *ssl_context* and *fingerprint* parameters. :param str server_hostname: Sets or overrides the host name that the target server's certificate will be matched against. See :py:meth:`asyncio.loop.create_connection` for more information. :param collections.abc.Mapping proxy_headers: HTTP headers to send to the proxy if the parameter proxy has been provided. :param trace_request_ctx: Object used to give as a kw param for each new :class:`TraceConfig` object instantiated, used to give information to the tracers that is only available at request time. 
:param int read_bufsize: Size of the read buffer (:attr:`ClientResponse.content`). ``None`` by default, it means that the session global value is used. .. versionadded:: 3.7 :param bool auto_decompress: Automatically decompress response body. May be used to enable/disable auto decompression on a per-request basis. :param int max_line_size: Maximum allowed size of lines in responses. :param int max_field_size: Maximum allowed size of header name and value combined in responses. :param int max_headers: Maximum number of headers and trailers combined in responses. :param aiohttp.protocol.HttpVersion version: Request HTTP version, ``HTTP 1.1`` by default. (optional) :param aiohttp.BaseConnector connector: BaseConnector sub-class instance to support connection pooling. (optional) :return ClientResponse: a :class:`client response <ClientResponse>` object. Usage:: import aiohttp async def fetch(): async with aiohttp.request('GET', 'http://python.org/') as resp: assert resp.status == 200 print(await resp.text()) .. _aiohttp-client-reference-connectors: Connectors ---------- Connectors are transports for aiohttp client API. There are standard connectors: 1. :class:`TCPConnector` for regular *TCP sockets* (both *HTTP* and *HTTPS* schemes supported). 2. :class:`UnixConnector` for connecting via UNIX socket (it's used mostly for testing purposes). All connector classes should be derived from :class:`BaseConnector`. By default all *connectors* support *keep-alive connections* (behavior is controlled by *force_close* constructor's parameter). .. class:: BaseConnector(*, keepalive_timeout=15, \ force_close=False, limit=100, limit_per_host=0, \ enable_cleanup_closed=False, loop=None) :canonical: aiohttp.connector.BaseConnector Base class for all connectors. :param float keepalive_timeout: timeout for connection reusing after releasing (optional). Values ``0``. For disabling *keep-alive* feature use ``force_close=True`` flag. :param int limit: total number of simultaneous connections.
If *limit* is ``0`` the connector has no limit (default: 100). :param int limit_per_host: limit simultaneous connections to the same endpoint. Endpoints are the same if they have an equal ``(host, port, is_ssl)`` triple. If *limit_per_host* is ``0`` the connector has no per-host limit (default: 0). :param bool force_close: close underlying sockets after connection releasing (optional). :param bool enable_cleanup_closed: some SSL servers do not properly complete SSL shutdown process, in that case asyncio leaks SSL connections. If this parameter is set to True, aiohttp additionally aborts underlying transport after 2 seconds. It is off by default. For Python version 3.12.7+, or 3.13.1 and later, this parameter is ignored because the asyncio SSL connection leak is fixed in these versions of Python. :param loop: :ref:`event loop<asyncio-event-loop>` used for handling connections. If param is ``None``, :func:`asyncio.get_event_loop` is used for getting default event loop. .. deprecated:: 2.0 .. attribute:: closed Read-only property, ``True`` if connector is closed. .. attribute:: force_close Read-only property, ``True`` if connector should ultimately close connections on releasing. .. attribute:: limit The total number of simultaneous connections. If limit is 0 the connector has no limit. The default limit size is 100. .. attribute:: limit_per_host The limit for simultaneous connections to the same endpoint. Endpoints are the same if they have an equal ``(host, port, is_ssl)`` triple. If *limit_per_host* is ``0`` the connector has no limit per host. Read-only property. .. method:: close() :async: Close all opened connections. .. method:: connect(request) :async: Get a free connection from pool or create new one if connection is absent in the pool. The call may be paused if :attr:`limit` is exhausted until used connections return to pool. :param aiohttp.ClientRequest request: request object which is connection initiator. :return: :class:`Connection` object. ..
method:: _create_connection(req) :async: Abstract method for actual connection establishing, should be overridden in subclasses. .. py:class:: AddrInfoType Refer to :py:data:`aiohappyeyeballs.AddrInfoType` for more info. .. warning:: Be sure to use ``aiohttp.AddrInfoType`` rather than ``aiohappyeyeballs.AddrInfoType`` to avoid import breakage, as it is likely to be removed from :mod:`aiohappyeyeballs` in the future. .. py:class:: SocketFactoryType Refer to :py:data:`aiohappyeyeballs.SocketFactoryType` for more info. .. warning:: Be sure to use ``aiohttp.SocketFactoryType`` rather than ``aiohappyeyeballs.SocketFactoryType`` to avoid import breakage, as it is likely to be removed from :mod:`aiohappyeyeballs` in the future. .. class:: TCPConnector(*, ssl=True, verify_ssl=True, fingerprint=None, \ use_dns_cache=True, ttl_dns_cache=10, \ family=0, ssl_context=None, local_addr=None, \ resolver=None, keepalive_timeout=sentinel, \ force_close=False, limit=100, limit_per_host=0, \ enable_cleanup_closed=False, timeout_ceil_threshold=5, \ happy_eyeballs_delay=0.25, interleave=None, loop=None, \ socket_factory=None, ssl_shutdown_timeout=0) :canonical: aiohttp.connector.TCPConnector Connector for working with *HTTP* and *HTTPS* via *TCP* sockets. The most common transport. When you don't know what connector type to use, use a :class:`TCPConnector` instance. :class:`TCPConnector` inherits from :class:`BaseConnector`. Constructor accepts all parameters suitable for :class:`BaseConnector` plus several TCP-specific ones: :param ssl: SSL validation mode. ``True`` for default SSL check (:func:`ssl.create_default_context` is used), ``False`` for skip SSL certificate validation, :class:`aiohttp.Fingerprint` for fingerprint validation, :class:`ssl.SSLContext` for custom SSL certificate validation. Supersedes *verify_ssl*, *ssl_context* and *fingerprint* parameters. .. versionadded:: 3.0 :param bool verify_ssl: perform SSL certificate validation for *HTTPS* requests (enabled by default). 
May be disabled to skip validation for sites with invalid certificates. .. deprecated:: 2.3 Pass *verify_ssl* to ``ClientSession.get()`` etc. :param bytes fingerprint: pass the SHA256 digest of the expected certificate in DER format to verify that the certificate the server presents matches. Useful for `certificate pinning <https://en.wikipedia.org/wiki/HTTP_Public_Key_Pinning>`_. Note: use of MD5 or SHA1 digests is insecure and deprecated. .. deprecated:: 2.3 Pass *verify_ssl* to ``ClientSession.get()`` etc. :param bool use_dns_cache: use internal cache for DNS lookups, ``True`` by default. Enabling this option *may* speed up connection establishment a bit but may introduce some *side effects* also. :param int ttl_dns_cache: expire after some seconds the DNS entries, ``None`` means cached forever. By default 10 seconds (optional). In some environments the IP addresses related to a specific HOST can change after a specific time. Use this option to keep the DNS cache updated refreshing each entry after N seconds. :param int limit: total number of simultaneous connections. If *limit* is ``0`` the connector has no limit (default: 100). :param int limit_per_host: limit simultaneous connections to the same endpoint. Endpoints are the same if they have an equal ``(host, port, is_ssl)`` triple. If *limit_per_host* is ``0`` the connector has no per-host limit (default: 0). :param aiohttp.abc.AbstractResolver resolver: custom resolver instance to use. ``aiohttp.DefaultResolver`` by default (asynchronous if ``aiodns>=1.1`` is installed). Custom resolvers allow resolving hostnames differently than the way the host is configured. The resolver is ``aiohttp.ThreadedResolver`` by default, asynchronous version is pretty robust but might fail in very rare cases. :param int family: TCP socket family, both IPv4 and IPv6 by default. For *IPv4* only use :data:`socket.AF_INET`, for *IPv6* only -- :data:`socket.AF_INET6`. *family* is ``0`` by default, that means both IPv4 and IPv6 are accepted.
To specify only concrete version please pass :data:`socket.AF_INET` or :data:`socket.AF_INET6` explicitly. :param ssl.SSLContext ssl_context: SSL context used for processing *HTTPS* requests (optional). *ssl_context* may be used for configuring certification authority channel, supported SSL options etc. :param tuple local_addr: tuple of ``(local_host, local_port)`` used to bind socket locally if specified. :param bool force_close: close underlying sockets after connection releasing (optional). :param bool enable_cleanup_closed: Some ssl servers do not properly complete SSL shutdown process, in that case asyncio leaks SSL connections. If this parameter is set to True, aiohttp additionally aborts underlining transport after 2 seconds. It is off by default. :param float happy_eyeballs_delay: The amount of time in seconds to wait for a connection attempt to complete, before starting the next attempt in parallel. This is the “Connection Attempt Delay” as defined in RFC 8305. To disable Happy Eyeballs, set this to ``None``. The default value recommended by the RFC is 0.25 (250 milliseconds). .. versionadded:: 3.10 :param int interleave: controls address reordering when a host name resolves to multiple IP addresses. If ``0`` or unspecified, no reordering is done, and addresses are tried in the order returned by the resolver. If a positive integer is specified, the addresses are interleaved by address family, and the given integer is interpreted as “First Address Family Count” as defined in RFC 8305. The default is ``0`` if happy_eyeballs_delay is not specified, and ``1`` if it is. .. versionadded:: 3.10 :param SocketFactoryType socket_factory: This function takes an :py:data:`AddrInfoType` and is used in lieu of :py:func:`socket.socket` when creating TCP connections. .. versionadded:: 3.12 :param float ssl_shutdown_timeout: **(DEPRECATED)** This parameter is deprecated and will be removed in aiohttp 4.0. 
Grace period for SSL shutdown on TLS connections when the connector is closed (``0`` seconds by default). By default (``0``), SSL connections are aborted immediately when the connector is closed, without performing the shutdown handshake. During normal operation, SSL connections use Python's default SSL shutdown behavior. Setting this to a positive value (e.g., ``0.1``) will perform a graceful shutdown when closing the connector, notifying the remote server which can help prevent "connection reset" errors at the cost of additional cleanup time. Note: On Python versions prior to 3.11, only a value of ``0`` is supported; other values will trigger a warning. .. versionadded:: 3.12.5 .. versionchanged:: 3.12.11 Changed default from ``0.1`` to ``0`` to abort SSL connections immediately when the connector is closed. Added support for ``ssl_shutdown_timeout=0`` on all Python versions. A :exc:`RuntimeWarning` is issued when non-zero values are passed on Python < 3.11. .. deprecated:: 3.12.11 This parameter is deprecated and will be removed in aiohttp 4.0. .. attribute:: family *TCP* socket family e.g. :data:`socket.AF_INET` or :data:`socket.AF_INET6` Read-only property. .. attribute:: dns_cache Use quick lookup in internal *DNS* cache for host names if ``True``. Read-only :class:`bool` property. .. attribute:: cached_hosts The cache of resolved hosts if :attr:`dns_cache` is enabled. Read-only :class:`types.MappingProxyType` property. .. method:: clear_dns_cache(self, host=None, port=None) Clear internal *DNS* cache. Remove specific entry if both *host* and *port* are specified, clear all cache otherwise. .. class:: UnixConnector(path, *, conn_timeout=None, \ keepalive_timeout=30, limit=100, \ force_close=False, loop=None) :canonical: aiohttp.connector.UnixConnector Unix socket connector. Use :class:`UnixConnector` for sending *HTTP/HTTPS* requests through *UNIX Sockets* as underlying transport. 
UNIX sockets are handy for writing tests and making very fast connections between processes on the same host. :class:`UnixConnector` is inherited from :class:`BaseConnector`. Usage:: conn = UnixConnector(path='/path/to/socket') session = ClientSession(connector=conn) async with session.get('http://python.org') as resp: ... Constructor accepts all parameters suitable for :class:`BaseConnector` plus UNIX-specific one: :param str path: Unix socket path .. attribute:: path Path to *UNIX socket*, read-only :class:`str` property. .. class:: Connection :canonical: aiohttp.connector.Connection Encapsulates single connection in connector object. End user should never create :class:`Connection` instances manually but get it by :meth:`BaseConnector.connect` coroutine. .. attribute:: closed :class:`bool` read-only property, ``True`` if connection was closed, released or detached. .. attribute:: loop Event loop used for connection .. deprecated:: 3.5 .. attribute:: transport Connection transport .. method:: close() Close connection with forcibly closing underlying socket. .. method:: release() Release connection back to connector. Underlying socket is not closed, the connection may be reused later if timeout (30 seconds by default) for connection was not expired. Response object --------------- .. class:: ClientResponse :canonical: aiohttp.client_reqrep.ClientResponse Client response returned by :meth:`aiohttp.ClientSession.request` and family. User never creates the instance of ClientResponse class but gets it from API calls. :class:`ClientResponse` supports async context manager protocol, e.g.:: resp = await client_session.get(url) async with resp: assert resp.status == 200 After exiting from ``async with`` block response object will be *released* (see :meth:`release` method). .. attribute:: version Response's version, :class:`~aiohttp.protocol.HttpVersion` instance. .. attribute:: status HTTP status code of response (:class:`int`), e.g. ``200``. .. 
attribute:: reason HTTP status reason of response (:class:`str`), e.g. ``"OK"``. .. attribute:: ok Boolean representation of HTTP status code (:class:`bool`). ``True`` if ``status`` is less than ``400``; otherwise, ``False``. .. attribute:: method Request's method (:class:`str`). .. attribute:: url URL of request (:class:`~yarl.URL`). .. attribute:: real_url Unmodified URL of request with URL fragment unstripped (:class:`~yarl.URL`). .. versionadded:: 3.2 .. attribute:: connection :class:`Connection` used for handling response. .. attribute:: content Payload stream, which contains response's BODY (:class:`StreamReader`). It supports various reading methods depending on the expected format. When chunked transfer encoding is used by the server, it allows retrieving the actual http chunks. Reading from the stream may raise :exc:`aiohttp.ClientPayloadError` if the response object is closed before the response receives all data, or in case of any transfer-encoding related errors such as malformed chunked encoding or broken compression data. .. attribute:: cookies HTTP cookies of response (*Set-Cookie* HTTP header, :class:`~http.cookies.SimpleCookie`). .. note:: Since :class:`~http.cookies.SimpleCookie` uses cookie name as the key, cookies with the same name but different domains or paths will be overwritten. Only the last cookie with a given name will be accessible via this attribute. To access all cookies, including duplicates with the same name, use :meth:`response.headers.getall('Set-Cookie') <multidict.MultiDictProxy.getall>`. The session's cookie jar will correctly store all cookies, even if they are not accessible via this attribute. .. attribute:: headers A case-insensitive multidict proxy with HTTP headers of response, :class:`~multidict.CIMultiDictProxy`. .. attribute:: raw_headers Unmodified HTTP headers of response as unconverted bytes, a sequence of ``(key, value)`` pairs. .. attribute:: links Link HTTP header parsed into a :class:`~multidict.MultiDictProxy`.
For each link, key is link param `rel` when it exists, or link url as :class:`str` otherwise, and value is :class:`~multidict.MultiDictProxy` of link params and url at key `url` as :class:`~yarl.URL` instance. .. versionadded:: 3.2 .. attribute:: content_type Read-only property with *content* part of *Content-Type* header. .. note:: Returns ``'application/octet-stream'`` if no Content-Type header is present or the value contains invalid syntax according to :rfc:`9110`. To see the original header check ``resp.headers["Content-Type"]``. To make sure Content-Type header is not present in the server reply, use :attr:`headers` or :attr:`raw_headers`, e.g. ``'Content-Type' not in resp.headers``. .. attribute:: charset Read-only property that specifies the *encoding* for the request's BODY. The value is parsed from the *Content-Type* HTTP header. Returns :class:`str` like ``'utf-8'`` or ``None`` if no *Content-Type* header present in HTTP headers or it has no charset information. .. attribute:: content_disposition Read-only property that specified the *Content-Disposition* HTTP header. Instance of :class:`ContentDisposition` or ``None`` if no *Content-Disposition* header present in HTTP headers. .. attribute:: history A :class:`~collections.abc.Sequence` of :class:`ClientResponse` objects of preceding requests (earliest request first) if there were redirects, an empty sequence otherwise. .. method:: close() Close response and underlying connection. For :term:`keep-alive` support see :meth:`release`. .. method:: read() :async: Read the whole response's body as :class:`bytes`. Close underlying connection if data reading gets an error, release connection otherwise. Raise an :exc:`aiohttp.ClientResponseError` if the data can't be read. :return bytes: read *BODY*. .. seealso:: :meth:`close`, :meth:`release`. .. method:: release() It is not required to call `release` on the response object. 
When the client fully receives the payload, the underlying connection automatically returns back to pool. If the payload is not fully read, the connection is closed .. method:: raise_for_status() Raise an :exc:`aiohttp.ClientResponseError` if the response status is 400 or higher. Do nothing for success responses (less than 400). .. method:: text(encoding=None) :async: Read response's body and return decoded :class:`str` using specified *encoding* parameter. If *encoding* is ``None`` content encoding is determined from the Content-Type header, or using the ``fallback_charset_resolver`` function. Close underlying connection if data reading gets an error, release connection otherwise. :param str encoding: text encoding used for *BODY* decoding, or ``None`` for encoding autodetection (default). :raises: :exc:`UnicodeDecodeError` if decoding fails. See also :meth:`get_encoding`. :return str: decoded *BODY* .. method:: json(*, encoding=None, loads=json.loads, \ content_type='application/json') :async: Read response's body as *JSON*, return :class:`dict` using specified *encoding* and *loader*. If data is not still available a ``read`` call will be done. If response's `content-type` does not match `content_type` parameter :exc:`aiohttp.ContentTypeError` get raised. To disable content type check pass ``None`` value. :param str encoding: text encoding used for *BODY* decoding, or ``None`` for encoding autodetection (default). By the standard JSON encoding should be ``UTF-8`` but practice beats purity: some servers return non-UTF responses. Autodetection works pretty fine anyway. :param collections.abc.Callable loads: :term:`callable` used for loading *JSON* data, :func:`json.loads` by default. :param str content_type: specify response's content-type, if content type does not match raise :exc:`aiohttp.ClientResponseError`. To disable `content-type` check, pass ``None`` as value. (default: `application/json`). 
:return: *BODY* as *JSON* data parsed by *loads* parameter or ``None`` if *BODY* is empty or contains white-spaces only. .. attribute:: request_info A :class:`typing.NamedTuple` with request URL and headers from :class:`~aiohttp.ClientRequest` object, :class:`aiohttp.RequestInfo` instance. .. method:: get_encoding() Retrieve content encoding using ``charset`` info in ``Content-Type`` HTTP header. If no charset is present or the charset is not understood by Python, the ``fallback_charset_resolver`` function associated with the ``ClientSession`` is called. .. versionadded:: 3.0 ClientWebSocketResponse ----------------------- To connect to a websocket server :func:`aiohttp.ws_connect` or :meth:`aiohttp.ClientSession.ws_connect` coroutines should be used, do not create an instance of class :class:`ClientWebSocketResponse` manually. .. class:: ClientWebSocketResponse() :canonical: aiohttp.client_ws.ClientWebSocketResponse Class for handling client-side websockets. .. attribute:: closed Read-only property, ``True`` if :meth:`close` has been called or :const:`~aiohttp.WSMsgType.CLOSE` message has been received from peer. .. attribute:: protocol Websocket *subprotocol* chosen after :meth:`start` call. May be ``None`` if server and client protocols are not overlapping. .. method:: get_extra_info(name, default=None) Reads optional extra information from the connection's transport. If no value associated with ``name`` is found, ``default`` is returned. See :meth:`asyncio.BaseTransport.get_extra_info` :param str name: The key to look up in the transport extra information. :param default: Default value to be used when no value for ``name`` is found (default is ``None``). .. method:: exception() Returns exception if any occurs or returns None. .. method:: ping(message=b'') :async: Send :const:`~aiohttp.WSMsgType.PING` to peer. :param message: optional payload of *ping* message, :class:`str` (converted to *UTF-8* encoded bytes) or :class:`bytes`. .. 
versionchanged:: 3.0 The method is converted into :term:`coroutine` .. method:: pong(message=b'') :async: Send :const:`~aiohttp.WSMsgType.PONG` to peer. :param message: optional payload of *pong* message, :class:`str` (converted to *UTF-8* encoded bytes) or :class:`bytes`. .. versionchanged:: 3.0 The method is converted into :term:`coroutine` .. method:: send_str(data, compress=None) :async: Send *data* to peer as :const:`~aiohttp.WSMsgType.TEXT` message. :param str data: data to send. :param int compress: sets specific level of compression for single message, ``None`` for not overriding per-socket setting. :raise TypeError: if data is not :class:`str` .. versionchanged:: 3.0 The method is converted into :term:`coroutine`, *compress* parameter added. .. method:: send_bytes(data, compress=None) :async: Send *data* to peer as :const:`~aiohttp.WSMsgType.BINARY` message. :param data: data to send. :param int compress: sets specific level of compression for single message, ``None`` for not overriding per-socket setting. :raise TypeError: if data is not :class:`bytes`, :class:`bytearray` or :class:`memoryview`. .. versionchanged:: 3.0 The method is converted into :term:`coroutine`, *compress* parameter added. .. method:: send_json(data, compress=None, *, dumps=json.dumps) :async: Send *data* to peer as JSON string. :param data: data to send. :param int compress: sets specific level of compression for single message, ``None`` for not overriding per-socket setting. :param collections.abc.Callable dumps: any :term:`callable` that accepts an object and returns a JSON string (:func:`json.dumps` by default). :raise RuntimeError: if connection is not started or closing :raise ValueError: if data is not serializable object :raise TypeError: if value returned by ``dumps(data)`` is not :class:`str` .. versionchanged:: 3.0 The method is converted into :term:`coroutine`, *compress* parameter added. .. 
method:: send_json_bytes(data, compress=None, *, dumps) :async: Send *data* to peer as a JSON binary frame using a bytes-returning encoder. :param data: data to send. :param int compress: sets specific level of compression for single message, ``None`` for not overriding per-socket setting. :param collections.abc.Callable dumps: any :term:`callable` that accepts an object and returns JSON as :class:`bytes` (e.g. ``orjson.dumps``). :raise RuntimeError: if connection is not started or closing :raise ValueError: if data is not serializable object :raise TypeError: if value returned by ``dumps(data)`` is not :class:`bytes` .. method:: send_frame(message, opcode, compress=None) :async: Send a :const:`~aiohttp.WSMsgType` message *message* to peer. This method is low-level and should be used with caution as it only accepts bytes which must conform to the correct message type for *message*. It is recommended to use the :meth:`send_str`, :meth:`send_bytes` or :meth:`send_json` methods instead of this method. The primary use case for this method is to send bytes that have already been encoded without having to decode and re-encode them. :param bytes message: message to send. :param ~aiohttp.WSMsgType opcode: opcode of the message. :param int compress: sets specific level of compression for single message, ``None`` for not overriding per-socket setting. .. versionadded:: 3.11 .. method:: close(*, code=WSCloseCode.OK, message=b'') :async: A :ref:`coroutine<coroutine>` that initiates closing handshake by sending :const:`~aiohttp.WSMsgType.CLOSE` message. It waits for close response from server. To add a timeout to `close()` call just wrap the call with `asyncio.wait()` or `asyncio.wait_for()`. :param int code: closing code. See also :class:`~aiohttp.WSCloseCode`. :param message: optional payload of *close* message, :class:`str` (converted to *UTF-8* encoded bytes) or :class:`bytes`. ..
method:: receive() :async: A :ref:`coroutine<coroutine>` that waits for an upcoming *data* message from peer and returns it. The coroutine implicitly handles :const:`~aiohttp.WSMsgType.PING`, :const:`~aiohttp.WSMsgType.PONG` and :const:`~aiohttp.WSMsgType.CLOSE` without returning the message. It processes the *ping-pong game* and performs the *closing handshake* internally. :return: :class:`~aiohttp.WSMessage` .. method:: receive_str() :async: A :ref:`coroutine<coroutine>` that calls :meth:`receive` but also asserts the message type is :const:`~aiohttp.WSMsgType.TEXT`. :return str: peer's message content. :raise aiohttp.WSMessageTypeError: if message is not :const:`~aiohttp.WSMsgType.TEXT`. .. method:: receive_bytes() :async: A :ref:`coroutine<coroutine>` that calls :meth:`receive` but also asserts the message type is :const:`~aiohttp.WSMsgType.BINARY`. :return bytes: peer's message content. :raise aiohttp.WSMessageTypeError: if message is not :const:`~aiohttp.WSMsgType.BINARY`. .. method:: receive_json(*, loads=json.loads) :async: A :ref:`coroutine<coroutine>` that calls :meth:`receive_str` and loads the JSON string to a Python dict. :param collections.abc.Callable loads: any :term:`callable` that accepts :class:`str` and returns :class:`dict` with parsed JSON (:func:`json.loads` by default). :return dict: loaded JSON content :raise TypeError: if message is :const:`~aiohttp.WSMsgType.BINARY`. :raise ValueError: if message is not valid JSON. ClientRequest ------------- .. class:: ClientRequest :canonical: aiohttp.client_reqrep.ClientRequest Represents an HTTP request to be sent by the client. This object encapsulates all the details of an HTTP request before it is sent. It is primarily used within client middleware to inspect or modify requests. .. note:: You typically don't create ``ClientRequest`` instances directly. They are created internally by :class:`ClientSession` methods and passed to middleware. For more information about using middleware, see :ref:`aiohttp-client-middleware`. ..
attribute:: body :type: Payload | Literal[b""] The request body payload (defaults to ``b""`` if no body passed). .. danger:: **DO NOT set this attribute directly!** Direct assignment will cause resource leaks. Always use :meth:`update_body` instead: .. code-block:: python # WRONG - This will leak resources! request.body = b"new data" # CORRECT - Use update_body await request.update_body(b"new data") Setting body directly bypasses cleanup of the previous payload, which can leave file handles open, streams unclosed, and buffers unreleased. Additionally, setting body directly must be done from within an event loop and is not thread-safe. Setting body outside of an event loop may raise RuntimeError when closing file-based payloads. .. attribute:: chunked :type: bool | None Whether to use chunked transfer encoding: - ``True``: Use chunked encoding - ``False``: Don't use chunked encoding - ``None``: Automatically determine based on body .. attribute:: compress :type: str | None The compression encoding for the request body. Common values include ``'gzip'`` and ``'deflate'``, but any string value is technically allowed. ``None`` means no compression. .. attribute:: headers :type: multidict.CIMultiDict The HTTP headers that will be sent with the request. This is a case-insensitive multidict that can be modified by middleware. .. code-block:: python # Add or modify headers request.headers['X-Custom-Header'] = 'value' request.headers['User-Agent'] = 'MyApp/1.0' .. attribute:: is_ssl :type: bool ``True`` if the request uses a secure scheme (e.g., HTTPS, WSS), ``False`` otherwise. .. attribute:: method :type: str The HTTP method of the request (e.g., ``'GET'``, ``'POST'``, ``'PUT'``, etc.). .. attribute:: original_url :type: yarl.URL The original URL passed to the request method, including any fragment. This preserves the exact URL as provided by the user. .. attribute:: proxy :type: yarl.URL | None The proxy URL if the request will be sent through a proxy, ``None`` otherwise. 
.. attribute:: proxy_headers :type: multidict.CIMultiDict | None Headers to be sent to the proxy server (e.g., ``Proxy-Authorization``). Only set when :attr:`proxy` is not ``None``. .. attribute:: response_class :type: type[ClientResponse] The class to use for creating the response object. Defaults to :class:`ClientResponse` but can be customized for special handling. .. attribute:: server_hostname :type: str | None Override the hostname for SSL certificate verification. Useful when connecting through proxies or to IP addresses. .. attribute:: session :type: ClientSession The client session that created this request. Useful for accessing session-level configuration or making additional requests within middleware. .. warning:: Be careful when making requests with the same session inside middleware to avoid infinite recursion. Use ``middlewares=()`` parameter when needed. .. attribute:: ssl :type: ssl.SSLContext | bool | Fingerprint SSL validation configuration for this request: - ``True``: Use default SSL verification - ``False``: Skip SSL verification - :class:`ssl.SSLContext`: Custom SSL context - :class:`Fingerprint`: Verify specific certificate fingerprint .. attribute:: url :type: yarl.URL The target URL of the request with the fragment (``#...``) part stripped. This is the actual URL that will be used for the connection. .. note:: To access the original URL with fragment, use :attr:`original_url`. .. attribute:: version :type: HttpVersion The HTTP version to use for the request (e.g., ``HttpVersion(1, 1)`` for HTTP/1.1). .. method:: update_body(body) Update the request body and close any existing payload to prevent resource leaks. **This is the ONLY correct way to modify a request body.** Never set the :attr:`body` attribute directly. This method is particularly useful in middleware when you need to modify the request body after the request has been created but before it's sent. :param body: The new body content. 
Can be: - ``bytes``/``bytearray``: Raw binary data - ``str``: Text data (encoded using charset from Content-Type) - :class:`FormData`: Form data encoded as multipart/form-data - :class:`Payload`: A pre-configured payload object - ``AsyncIterable[bytes]``: Async iterable of bytes chunks - File-like object: Will be read and sent as binary data - ``None``: Clears the body .. code-block:: python async def middleware(request, handler): # Modify request body in middleware if request.method == 'POST': # CORRECT: Always use update_body await request.update_body(b'{"modified": true}') # WRONG: Never set body directly! # request.body = b'{"modified": true}' # This leaks resources! # Or add authentication data to form if isinstance(request.body, FormData): form = FormData() # Copy existing fields and add auth token form.add_field('auth_token', 'secret123') await request.update_body(form) return await handler(request) .. note:: This method is async because it may need to close file handles or other resources associated with the previous payload. Always await this method to ensure proper cleanup. .. danger:: **Never set :attr:`ClientRequest.body` directly!** Direct assignment will cause resource leaks. Always use this method instead. Setting the body attribute directly: - Bypasses cleanup of the previous payload - Leaves file handles and streams open - Can cause memory leaks - May result in unexpected behavior with async iterables .. warning:: When updating the body, ensure that the Content-Type header is appropriate for the new body content. The Content-Length header will be updated automatically. When using :class:`FormData` or :class:`Payload` objects, headers are updated automatically, but you may need to set Content-Type manually for raw bytes or text. It is not recommended to change the payload type in middleware. 
If the body was already set (e.g., as bytes), it's best to keep the same type rather than converting it (e.g., to str) as this may result in unexpected behavior. .. versionadded:: 3.12 Utilities --------- .. class:: ClientTimeout(*, total=None, connect=None, \ sock_connect=None, sock_read=None) :canonical: aiohttp.client.ClientTimeout A data class for client timeout settings. See :ref:`aiohttp-client-timeouts` for usage examples. .. attribute:: total Total number of seconds for the whole request. :class:`float`, ``None`` by default. .. attribute:: connect Maximal number of seconds for acquiring a connection from pool. The time consists of connection establishment for a new connection or waiting for a free connection from a pool if pool connection limits are exceeded. For pure socket connection establishment time use :attr:`sock_connect`. :class:`float`, ``None`` by default. .. attribute:: sock_connect Maximal number of seconds for connecting to a peer for a new connection, not given from a pool. See also :attr:`connect`. :class:`float`, ``None`` by default. .. attribute:: sock_read Maximal number of seconds for reading a portion of data from a peer. :class:`float`, ``None`` by default. .. class:: ClientWSTimeout(*, ws_receive=None, ws_close=None) :canonical: aiohttp.client_ws.ClientWSTimeout A data class for websocket client timeout settings. .. attribute:: ws_receive A timeout for websocket to receive a complete message. :class:`float`, ``None`` by default. .. attribute:: ws_close A timeout for the websocket to close. :class:`float`, ``10.0`` by default. .. note:: Timeouts of 5 seconds or more are rounded for scheduling on the next second boundary (an absolute time where microseconds part is zero) for the sake of performance. E.g., assume a timeout is ``10``, absolute time when timeout should expire is ``loop.time() + 10``, and it points to ``12345.67 + 10`` which is equal to ``12355.67``. The absolute time for the timeout cancellation is ``12356``.
It leads to grouping all close scheduled timeout expirations to exactly the same time to reduce amount of loop wakeups. .. versionchanged:: 3.7 Rounding to the next seconds boundary is disabled for timeouts smaller than 5 seconds for the sake of easy debugging. In turn, tiny timeouts can lead to significant performance degradation on production environment. .. class:: ETag(name, is_weak=False) :canonical: aiohttp.helpers.ETag Represents `ETag` identifier. .. attribute:: value Value of corresponding etag without quotes. .. attribute:: is_weak Flag indicates that etag is weak (has `W/` prefix). .. versionadded:: 3.8 .. class:: ContentDisposition :canonical: aiohttp.client_reqrep.ContentDisposition A data class to represent the Content-Disposition header, available as :attr:`ClientResponse.content_disposition` attribute. .. attribute:: type A :class:`str` instance. Value of Content-Disposition header itself, e.g. ``attachment``. .. attribute:: filename A :class:`str` instance. Content filename extracted from parameters. May be ``None``. .. attribute:: parameters Read-only mapping contains all parameters. .. class:: RequestInfo() :canonical: aiohttp.client_reqrep.RequestInfo A :class:`typing.NamedTuple` with request URL and headers from :class:`~aiohttp.ClientRequest` object, available as :attr:`ClientResponse.request_info` attribute. .. attribute:: url Requested *url*, :class:`yarl.URL` instance. .. attribute:: method Request HTTP method like ``'GET'`` or ``'POST'``, :class:`str`. .. attribute:: headers HTTP headers for request, :class:`multidict.CIMultiDict` instance. .. attribute:: real_url Requested *url* with URL fragment unstripped, :class:`yarl.URL` instance. .. versionadded:: 3.2 .. class:: BasicAuth(login, password='', encoding='latin1') :canonical: aiohttp.helpers.BasicAuth HTTP basic authentication helper. 
:param str login: login :param str password: password :param str encoding: encoding (``'latin1'`` by default) Should be used for specifying authorization data in client API, e.g. *auth* parameter for :meth:`ClientSession.request() <aiohttp.ClientSession.request>`. .. classmethod:: decode(auth_header, encoding='latin1') Decode HTTP basic authentication credentials. :param str auth_header: The ``Authorization`` header to decode. :param str encoding: (optional) encoding ('latin1' by default) :return: decoded authentication data, :class:`BasicAuth`. .. classmethod:: from_url(url) Construct credentials info from url's *user* and *password* parts. :return: credentials data, :class:`BasicAuth` or ``None`` if credentials are not provided. .. versionadded:: 2.3 .. method:: encode() Encode credentials into string suitable for ``Authorization`` header etc. :return: encoded authentication data, :class:`str`. .. class:: DigestAuthMiddleware(login, password, *, preemptive=True) :canonical: aiohttp.client_middleware_digest_auth.DigestAuthMiddleware HTTP digest authentication client middleware. :param str login: login :param str password: password :param bool preemptive: Enable preemptive authentication (default: ``True``) This middleware supports HTTP digest authentication with both `auth` and `auth-int` quality of protection (qop) modes, and a variety of hashing algorithms. It automatically handles the digest authentication handshake by: - Parsing 401 Unauthorized responses with `WWW-Authenticate: Digest` headers - Generating appropriate `Authorization: Digest` headers on retry - Maintaining nonce counts and challenge data per request - When ``preemptive=True``, reusing authentication credentials for subsequent requests to the same protection space (following RFC 7616 Section 3.6) **Preemptive Authentication** By default (``preemptive=True``), the middleware remembers successful authentication challenges and automatically includes the Authorization header in subsequent requests to the same protection space.
This behavior: - Improves server efficiency by avoiding extra round trips - Matches how modern web browsers handle digest authentication - Follows the recommendation in RFC 7616 Section 3.6 The server may still respond with a 401 status and ``stale=true`` if the nonce has expired, in which case the middleware will automatically retry with the new nonce. To disable preemptive authentication and require a 401 challenge for every request, set ``preemptive=False``:: # Default behavior - preemptive auth enabled digest_auth_middleware = DigestAuthMiddleware(login="user", password="pass") # Disable preemptive auth - always wait for 401 challenge digest_auth_middleware = DigestAuthMiddleware(login="user", password="pass", preemptive=False) Usage:: digest_auth_middleware = DigestAuthMiddleware(login="user", password="pass") async with ClientSession(middlewares=(digest_auth_middleware,)) as session: async with session.get("http://protected.example.com") as resp: # The middleware automatically handles the digest auth handshake assert resp.status == 200 # Subsequent requests include auth header preemptively async with session.get("http://protected.example.com/other") as resp: assert resp.status == 200 # No 401 round trip needed .. versionadded:: 3.12 .. versionchanged:: 3.12.8 Added ``preemptive`` parameter to enable/disable preemptive authentication. .. class:: CookieJar(*, unsafe=False, quote_cookie=True, treat_as_secure_origin = []) :canonical: aiohttp.cookiejar.CookieJar The cookie jar instance is available as :attr:`ClientSession.cookie_jar`. The jar contains :class:`~http.cookies.Morsel` items for storing internal cookie data. API provides a count of saved cookies:: len(session.cookie_jar) These cookies may be iterated over:: for cookie in session.cookie_jar: print(cookie.key) print(cookie["domain"]) The class implements :class:`collections.abc.Iterable`, :class:`collections.abc.Sized` and :class:`aiohttp.abc.AbstractCookieJar` interfaces. 
Implements cookie storage adhering to RFC 6265. :param bool unsafe: (optional) Whether to accept cookies from IPs. :param bool quote_cookie: (optional) Whether to quote cookies according to :rfc:`2109`. Some backend systems (not compatible with RFC mentioned above) do not support quoted cookies. .. versionadded:: 3.7 :param treat_as_secure_origin: (optional) Mark origins as secure for cookies marked as ``Secure``. Possible types are: - :class:`tuple` or :class:`list` of :class:`str` or :class:`yarl.URL` - :class:`str` - :class:`yarl.URL` .. versionadded:: 3.8 .. method:: update_cookies(cookies, response_url=None) Update cookies returned by server in ``Set-Cookie`` header. :param cookies: a :class:`collections.abc.Mapping` (e.g. :class:`dict`, :class:`~http.cookies.SimpleCookie`) or *iterable* of *pairs* with cookies returned by server's response. :param ~yarl.URL response_url: URL of response, ``None`` for *shared cookies*. Regular cookies are coupled with server's URL and are sent only to this server, shared ones are sent in every client request. .. method:: filter_cookies(request_url) Return jar's cookies acceptable for URL and available in ``Cookie`` header for sending client requests for given URL. :param ~yarl.URL request_url: request's URL for which cookies are asked. :return: :class:`http.cookies.SimpleCookie` with filtered cookies for given URL. .. method:: save(file_path) Write a JSON representation of cookies into the file at provided path. .. versionchanged:: 3.14 Previously used pickle format. Now uses JSON for safe serialization. :param file_path: Path to file where cookies will be serialized, :class:`str` or :class:`pathlib.Path` instance. .. method:: load(file_path) Load cookies from the file at provided path. Tries JSON format first, then falls back to legacy pickle format (using a restricted unpickler that only allows cookie-related types) for backward compatibility with existing cookie files. .. 
versionchanged:: 3.14 Now loads JSON format by default. Falls back to restricted pickle for files saved by older versions. :param file_path: Path to file from where cookies will be imported, :class:`str` or :class:`pathlib.Path` instance. .. method:: clear(predicate=None) Removes all cookies from the jar if the predicate is ``None``. Otherwise removes only those :class:`~http.cookies.Morsel` for which ``predicate(morsel)`` returns ``True``. :param predicate: callable that gets :class:`~http.cookies.Morsel` as a parameter and returns ``True`` if this :class:`~http.cookies.Morsel` must be deleted from the jar. .. versionadded:: 4.0 .. method:: clear_domain(domain) Remove all cookies from the jar that belong to the specified domain or its subdomains. :param str domain: domain for which cookies must be deleted from the jar. .. versionadded:: 4.0 .. class:: DummyCookieJar(*, loop=None) :canonical: aiohttp.cookiejar.DummyCookieJar Dummy cookie jar which does not store cookies but ignores them. Could be useful e.g. for web crawlers to iterate over Internet without blowing up with saved cookies information. To install dummy cookie jar pass it into session instance:: jar = aiohttp.DummyCookieJar() session = aiohttp.ClientSession(cookie_jar=jar) .. class:: Fingerprint(digest) :canonical: aiohttp.client_reqrep.Fingerprint Fingerprint helper for checking SSL certificates by *SHA256* digest. :param bytes digest: *SHA256* digest for certificate in DER-encoded binary form (see :meth:`ssl.SSLSocket.getpeercert`). To check fingerprint pass the object into :meth:`ClientSession.get` call, e.g.:: import hashlib with open(path_to_cert, 'rb') as f: digest = hashlib.sha256(f.read()).digest() await session.get(url, ssl=aiohttp.Fingerprint(digest)) .. versionadded:: 3.0 .. function:: set_zlib_backend(lib) :canonical: aiohttp.compression_utils.set_zlib_backend Sets the compression backend for zlib-based operations.
This function allows you to override the default zlib backend used internally by passing a module that implements the standard compression interface. The module should implement at minimum the exact interface offered by the latest version of zlib. :param types.ModuleType lib: A module that implements the zlib-compatible compression API. Example usage:: import zlib_ng.zlib_ng as zng import aiohttp aiohttp.set_zlib_backend(zng) .. note:: aiohttp has been tested internally with :mod:`zlib`, :mod:`zlib_ng.zlib_ng`, and :mod:`isal.isal_zlib`. .. versionadded:: 3.12 FormData ^^^^^^^^ A :class:`FormData` object contains the form data and also handles encoding it into a body that is either ``multipart/form-data`` or ``application/x-www-form-urlencoded``. ``multipart/form-data`` is used if at least one field is an :class:`io.IOBase` object or was added with at least one optional argument to :meth:`add_field` (``content_type``, ``filename``, or ``content_transfer_encoding``). Otherwise, ``application/x-www-form-urlencoded`` is used. :class:`FormData` instances are callable and return a :class:`aiohttp.payload.Payload` on being called. .. class:: FormData(fields, quote_fields=True, charset=None) :canonical: aiohttp.formdata.FormData Helper class for multipart/form-data and application/x-www-form-urlencoded body generation. :param fields: A container for the key/value pairs of this form. Possible types are: - :class:`dict` - :class:`tuple` or :class:`list` - :class:`io.IOBase`, e.g. a file-like object - :class:`multidict.MultiDict` or :class:`multidict.MultiDictProxy` If it is a :class:`tuple` or :class:`list`, it must be a valid argument for :meth:`add_fields`. For :class:`dict`, :class:`multidict.MultiDict`, and :class:`multidict.MultiDictProxy`, the keys and values must be valid `name` and `value` arguments to :meth:`add_field`, respectively. .. method:: add_field(name, value, content_type=None, filename=None,\ content_transfer_encoding=None) Add a field to the form. 
:param str name: Name of the field :param value: Value of the field Possible types are: - :class:`str` - :class:`bytes`, :class:`bytearray`, or :class:`memoryview` - :class:`io.IOBase`, e.g. a file-like object :param str content_type: The field's content-type header (optional) :param str filename: The field's filename (optional) If this is not set and ``value`` is a :class:`bytes`, :class:`bytearray`, or :class:`memoryview` object, the `name` argument is used as the filename unless ``content_transfer_encoding`` is specified. If ``filename`` is not set and ``value`` is an :class:`io.IOBase` object, the filename is extracted from the object if possible. :param str content_transfer_encoding: The field's content-transfer-encoding header (optional) .. method:: add_fields(fields) Add one or more fields to the form. :param fields: An iterable containing: - :class:`io.IOBase`, e.g. a file-like object - :class:`multidict.MultiDict` or :class:`multidict.MultiDictProxy` - :class:`tuple` or :class:`list` of length two, containing a name-value pair Client exceptions ----------------- Exception hierarchy has been significantly modified in version 2.0. aiohttp defines only exceptions that covers connection handling and server response misbehaviors. For developer specific mistakes, aiohttp uses python standard exceptions like :exc:`ValueError` or :exc:`TypeError`. Reading a response content may raise a :exc:`ClientPayloadError` exception. This exception indicates errors specific to the payload encoding. Such as invalid compressed data, malformed chunked-encoded chunks or not enough data that satisfy the content-length header. All exceptions are available as members of *aiohttp* module. .. exception:: ClientError :canonical: aiohttp.client_exceptions.ClientError Base class for all client specific exceptions. Derived from :exc:`Exception` .. 
class:: ClientPayloadError :canonical: aiohttp.client_exceptions.ClientPayloadError This exception can only be raised while reading the response payload if one of these errors occurs: 1. invalid compression 2. malformed chunked encoding 3. not enough data that satisfy ``Content-Length`` HTTP header. Derived from :exc:`ClientError` .. exception:: InvalidURL :canonical: aiohttp.client_exceptions.InvalidURL URL used for fetching is malformed, e.g. it does not contain host part. Derived from :exc:`ClientError` and :exc:`ValueError` .. attribute:: url Invalid URL, :class:`yarl.URL` instance. .. attribute:: description Invalid URL description, :class:`str` instance or :data:`None`. .. exception:: InvalidUrlClientError :canonical: aiohttp.client_exceptions.InvalidUrlClientError Base class for all errors related to client url. Derived from :exc:`InvalidURL` .. exception:: RedirectClientError :canonical: aiohttp.client_exceptions.RedirectClientError Base class for all errors related to client redirects. Derived from :exc:`ClientError` .. exception:: NonHttpUrlClientError :canonical: aiohttp.client_exceptions.NonHttpUrlClientError Base class for all errors related to non http client urls. Derived from :exc:`ClientError` .. exception:: InvalidUrlRedirectClientError :canonical: aiohttp.client_exceptions.InvalidUrlRedirectClientError Redirect URL is malformed, e.g. it does not contain host part. Derived from :exc:`InvalidUrlClientError` and :exc:`RedirectClientError` .. exception:: NonHttpUrlRedirectClientError :canonical: aiohttp.client_exceptions.NonHttpUrlRedirectClientError Redirect URL does not contain http schema. Derived from :exc:`RedirectClientError` and :exc:`NonHttpUrlClientError` Response errors ^^^^^^^^^^^^^^^ .. exception:: ClientResponseError :canonical: aiohttp.client_exceptions.ClientResponseError These exceptions could happen after we get response from server. Derived from :exc:`ClientError` .. 
attribute:: request_info Instance of :class:`RequestInfo` object, contains information about request. .. attribute:: status HTTP status code of response (:class:`int`), e.g. ``400``. .. attribute:: message Message of response (:class:`str`), e.g. ``"OK"``. .. attribute:: headers Headers in response, a list of pairs. .. attribute:: history History from failed response, if available, else empty tuple. A :class:`tuple` of :class:`ClientResponse` objects used for handle redirection responses. .. attribute:: code HTTP status code of response (:class:`int`), e.g. ``400``. .. deprecated:: 3.1 .. class:: ContentTypeError :canonical: aiohttp.client_exceptions.ContentTypeError Invalid content type. Derived from :exc:`ClientResponseError` .. versionadded:: 2.3 .. class:: TooManyRedirects :canonical: aiohttp.client_exceptions.TooManyRedirects Client was redirected too many times. Maximum number of redirects can be configured by using parameter ``max_redirects`` in :meth:`request`. Derived from :exc:`ClientResponseError` .. versionadded:: 3.2 .. class:: WSServerHandshakeError :canonical: aiohttp.client_exceptions.WSServerHandshakeError Web socket server response error. Derived from :exc:`ClientResponseError` .. exception:: WSMessageTypeError :canonical: aiohttp.client_exceptions.WSMessageTypeError Received WebSocket message of unexpected type Derived from :exc:`TypeError` Connection errors ^^^^^^^^^^^^^^^^^ .. class:: ClientConnectionError :canonical: aiohttp.client_exceptions.ClientConnectionError These exceptions related to low-level connection problems. Derived from :exc:`ClientError` .. class:: ClientConnectionResetError :canonical: aiohttp.client_exceptions.ClientConnectionResetError Derived from :exc:`ClientConnectionError` and :exc:`ConnectionResetError` .. class:: ClientOSError :canonical: aiohttp.client_exceptions.ClientOSError Subset of connection errors that are initiated by an :exc:`OSError` exception. Derived from :exc:`ClientConnectionError` and :exc:`OSError` .. 
class:: ClientConnectorError :canonical: aiohttp.client_exceptions.ClientConnectorError Connector related exceptions. Derived from :exc:`ClientOSError` .. class:: ClientConnectorDNSError :canonical: aiohttp.client_exceptions.ClientConnectorDNSError DNS resolution error. Derived from :exc:`ClientConnectorError` .. class:: ClientProxyConnectionError :canonical: aiohttp.client_exceptions.ClientProxyConnectionError Derived from :exc:`ClientConnectorError` .. class:: ClientSSLError :canonical: aiohttp.client_exceptions.ClientSSLError Derived from :exc:`ClientConnectorError` .. class:: ClientConnectorSSLError :canonical: aiohttp.client_exceptions.ClientConnectorSSLError Response ssl error. Derived from :exc:`ClientSSLError` and :exc:`ssl.SSLError` .. class:: ClientConnectorCertificateError :canonical: aiohttp.client_exceptions.ClientConnectorCertificateError Response certificate error. Derived from :exc:`ClientSSLError` and :exc:`ssl.CertificateError` .. class:: UnixClientConnectorError :canonical: aiohttp.client_exceptions.UnixClientConnectorError Derived from :exc:`ClientConnectorError` .. class:: ServerConnectionError :canonical: aiohttp.client_exceptions.ServerConnectionError Derived from :exc:`ClientConnectionError` .. class:: ServerDisconnectedError :canonical: aiohttp.client_exceptions.ServerDisconnectedError Server disconnected. Derived from :exc:`~aiohttp.ServerConnectionError` .. attribute:: message Partially parsed HTTP message (optional). .. class:: ServerFingerprintMismatch :canonical: aiohttp.client_exceptions.ServerFingerprintMismatch Server fingerprint mismatch. Derived from :exc:`ServerConnectionError` .. class:: ServerTimeoutError :canonical: aiohttp.client_exceptions.ServerTimeoutError Server operation timeout: read timeout, etc. To catch all timeouts, including the ``total`` timeout, use :exc:`asyncio.TimeoutError`. Derived from :exc:`ServerConnectionError` and :exc:`asyncio.TimeoutError` .. 
class:: ConnectionTimeoutError :canonical: aiohttp.client_exceptions.ConnectionTimeoutError Connection timeout on ``connect`` and ``sock_connect`` timeouts. Derived from :exc:`ServerTimeoutError` .. class:: SocketTimeoutError :canonical: aiohttp.client_exceptions.SocketTimeoutError Reading from socket timeout on ``sock_read`` timeout. Derived from :exc:`ServerTimeoutError` Hierarchy of exceptions ^^^^^^^^^^^^^^^^^^^^^^^ * :exc:`ClientError` * :exc:`ClientConnectionError` * :exc:`ClientConnectionResetError` * :exc:`ClientOSError` * :exc:`ClientConnectorError` * :exc:`ClientProxyConnectionError` * :exc:`ClientConnectorDNSError` * :exc:`ClientSSLError` * :exc:`ClientConnectorCertificateError` * :exc:`ClientConnectorSSLError` * :exc:`UnixClientConnectorError` * :exc:`ServerConnectionError` * :exc:`ServerDisconnectedError` * :exc:`ServerFingerprintMismatch` * :exc:`ServerTimeoutError` * :exc:`ConnectionTimeoutError` * :exc:`SocketTimeoutError` * :exc:`ClientPayloadError` * :exc:`ClientResponseError` * :exc:`~aiohttp.ClientHttpProxyError` * :exc:`ContentTypeError` * :exc:`TooManyRedirects` * :exc:`WSServerHandshakeError` * :exc:`InvalidURL` * :exc:`InvalidUrlClientError` * :exc:`InvalidUrlRedirectClientError` * :exc:`NonHttpUrlClientError` * :exc:`NonHttpUrlRedirectClientError` * :exc:`RedirectClientError` * :exc:`InvalidUrlRedirectClientError` * :exc:`NonHttpUrlRedirectClientError` Client Types ------------ .. type:: ClientMiddlewareType Type alias for client middleware functions. Middleware functions must have this signature:: Callable[ [ClientRequest, ClientHandlerType], Awaitable[ClientResponse] ] .. 
type:: ClientHandlerType Type alias for client request handler functions:: Callable[[ClientRequest], Awaitable[ClientResponse]] ================================================ FILE: docs/code/client_middleware_cookbook.py ================================================ """This is a collection of semi-complete examples that get included into the cookbook page.""" import asyncio import logging import time from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager, suppress from aiohttp import ( ClientError, ClientHandlerType, ClientRequest, ClientResponse, ClientSession, TCPConnector, ) from aiohttp.abc import ResolveResult from aiohttp.tracing import Trace class SSRFError(ClientError): """A request was made to a blacklisted host.""" async def retry_middleware( req: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: for _ in range(3): # Try up to 3 times resp = await handler(req) if resp.ok: return resp return resp # type: ignore[possibly-undefined] async def api_logging_middleware( req: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: # We use middlewares=() to avoid infinite recursion. 
async with req.session.post("/log", data=req.url.host, middlewares=()) as resp: if not resp.ok: logging.warning("Log endpoint failed") return await handler(req) class TokenRefresh401Middleware: def __init__(self, refresh_token: str, access_token: str): self.access_token = access_token self.refresh_token = refresh_token self.lock = asyncio.Lock() async def __call__( self, req: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: for _ in range(2): # Retry at most one time token = self.access_token req.headers["Authorization"] = f"Bearer {token}" resp = await handler(req) if resp.status != 401: return resp async with self.lock: if token != self.access_token: # Already refreshed continue url = "https://api.example/refresh" async with req.session.post(url, data=self.refresh_token) as resp: # Add error handling as needed data = await resp.json() self.access_token = data["access_token"] return resp # type: ignore[possibly-undefined] class TokenRefreshExpiryMiddleware: def __init__(self, refresh_token: str): self.access_token = "" self.expires_at = 0 self.refresh_token = refresh_token self.lock = asyncio.Lock() async def __call__( self, req: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: if self.expires_at <= time.time(): token = self.access_token async with self.lock: if token == self.access_token: # Still not refreshed url = "https://api.example/refresh" async with req.session.post(url, data=self.refresh_token) as resp: # Add error handling as needed data = await resp.json() self.access_token = data["access_token"] self.expires_at = data["expires_at"] req.headers["Authorization"] = f"Bearer {self.access_token}" return await handler(req) async def token_refresh_preemptively_example() -> None: async def set_token(session: ClientSession, event: asyncio.Event) -> None: while True: async with session.post("/refresh") as resp: token = await resp.json() session.headers["Authorization"] = f"Bearer {token['auth']}" event.set() await 
asyncio.sleep(token["valid_duration"]) @asynccontextmanager async def auto_refresh_client() -> AsyncIterator[ClientSession]: async with ClientSession() as session: ready = asyncio.Event() t = asyncio.create_task(set_token(session, ready)) await ready.wait() yield session t.cancel() with suppress(asyncio.CancelledError): await t async with auto_refresh_client() as sess: ... async def ssrf_middleware( req: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: # WARNING: This is a simplified example for demonstration purposes only. # A complete implementation should also check: # - IPv6 loopback (::1) # - Private IP ranges (10.x.x.x, 192.168.x.x, 172.16-31.x.x) # - Link-local addresses (169.254.x.x, fe80::/10) # - Other internal hostnames and aliases if req.url.host in {"127.0.0.1", "localhost"}: raise SSRFError(req.url.host) return await handler(req) class SSRFConnector(TCPConnector): async def _resolve_host( self, host: str, port: int, traces: Sequence[Trace] | None = None ) -> list[ResolveResult]: res = await super()._resolve_host(host, port, traces) # WARNING: This is a simplified example - should also check ::1, private ranges, etc. if any(r["host"] in {"127.0.0.1"} for r in res): raise SSRFError() return res ================================================ FILE: docs/conf.py ================================================ #!/usr/bin/env python3 # # aiohttp documentation build configuration file, created by # sphinx-quickstart on Wed Mar 5 12:35:35 2014. # # This file is execfile()d with the current directory set to its # containing dir. # # Note that not all possible configuration values are present in this # autogenerated file. # # All configuration values have a default; values that are commented out # serve to show the default. 
import os
import re
from pathlib import Path

PROJECT_ROOT_DIR = Path(__file__).parents[1].resolve()
# True only for Read the Docs builds of a tagged release.
IS_RELEASE_ON_RTD = (
    os.getenv("READTHEDOCS", "False") == "True"
    and os.environ["READTHEDOCS_VERSION_TYPE"] == "tag"
)
if IS_RELEASE_ON_RTD:
    # ``tags`` is not defined in this file; Sphinx injects it into the
    # namespace in which conf.py is executed.
    tags.add("is_release")  # noqa: F821

_docs_path = os.path.dirname(__file__)
_version_path = os.path.abspath(
    os.path.join(_docs_path, "..", "aiohttp", "__init__.py")
)
# Read the version from the package source instead of importing aiohttp.
with open(_version_path, encoding="latin1") as fp:
    _version_match = re.search(
        r'^__version__ = "'
        r"(?P<major>\d+)"
        r"\.(?P<minor>\d+)"
        r"\.(?P<patch>\d+)"
        r'(?P<tag>.*)?"$',
        fp.read(),
        re.M,
    )
# re.search() returns None when nothing matches, so test for that explicitly
# rather than relying on an exception from ``.groupdict()``.
if _version_match is None:
    raise RuntimeError("Unable to determine version.")
_version_info = _version_match.groupdict()

# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    # Extensions bundled with Sphinx:
    "sphinx.ext.autodoc",
    "sphinx.ext.extlinks",
    "sphinx.ext.graphviz",
    "sphinx.ext.intersphinx",
    "sphinx.ext.viewcode",
    # Third-party extensions:
    "sphinxcontrib.towncrier.ext",  # provides `towncrier-draft-entries` directive
]

# Spell checking is optional: enable it only when the extension is installed.
try:
    import sphinxcontrib.spelling  # noqa
except ImportError:
    pass
else:
    extensions.append("sphinxcontrib.spelling")

intersphinx_mapping = {
    "pytest": ("https://docs.pytest.org/en/latest/", None),
    "python": ("https://docs.python.org/3", None),
    "multidict": ("https://multidict.readthedocs.io/en/stable/", None),
    "propcache": ("https://propcache.aio-libs.org/en/stable", None),
    "yarl": ("https://yarl.readthedocs.io/en/stable/", None),
    "aiosignal": ("https://aiosignal.readthedocs.io/en/stable/", None),
    "aiohttpjinja2": ("https://aiohttp-jinja2.readthedocs.io/en/stable/", None),
    "aiohttpremotes": ("https://aiohttp-remotes.readthedocs.io/en/stable/", None),
    "aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None),
    "aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None),
    "aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None),
    "aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/latest/", None),
    "isal": ("https://python-isal.readthedocs.io/en/stable/", None),
    "zlib_ng": ("https://python-zlib-ng.readthedocs.io/en/stable/", None),
}

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The suffix of source filenames.
source_suffix = ".rst"

# The encoding of source files.
# source_encoding = 'utf-8-sig'

# The master toctree document.
master_doc = "index"

# -- Project information -----------------------------------------------------

github_url = "https://github.com"
github_repo_org = "aio-libs"
github_repo_name = "aiohttp"
github_repo_slug = f"{github_repo_org}/{github_repo_name}"
github_repo_url = f"{github_url}/{github_repo_slug}"
github_sponsors_url = f"{github_url}/sponsors"

project = github_repo_name
copyright = f"{project} contributors"

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = "{major}.{minor}".format(**_version_info)
# The full version, including alpha/beta/rc tags.
release = "{major}.{minor}.{patch}{tag}".format(**_version_info)

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
# language = None

# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
# today = ''
# Else, today_fmt is used as the format for a strftime call.
# today_fmt = '%B %d, %Y'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ["_build"]

# The reST default role (used for this markup: `text`) to use for all
# documents.
# default_role = None

# When true, cross-reference text for :func: and similar roles gets a
# trailing '()'.
# add_function_parentheses = True

# When true, description unit titles (such as .. function::) are prefixed
# with the current module name.
# add_module_names = True

# When true, sectionauthor and moduleauthor directives are rendered in the
# output. They are ignored by default.
# show_authors = False

# Pygments (syntax highlighting) style name.
# pygments_style = 'sphinx'

# Language assumed for code blocks that do not name one.
highlight_language = "python3"

# Prefixes to ignore when sorting the module index.
# modindex_common_prefix = []

# When true, warnings stay in the built documents as "system message"
# paragraphs.
# keep_warnings = False

# -- Extension configuration -------------------------------------------------

# -- Options for extlinks extension ---------------------------------------

# Each entry maps a role name to (URL template, caption template).
extlinks = dict(
    issue=(f"{github_repo_url}/issues/%s", "#%s"),
    pr=(f"{github_repo_url}/pull/%s", "PR #%s"),
    commit=(f"{github_repo_url}/commit/%s", "%s"),
    gh=(f"{github_url}/%s", "GitHub: %s"),
    user=(f"{github_sponsors_url}/%s", "@%s"),
)

# -- Options for HTML output ----------------------------------------------

# Theme used for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = "aiohttp_theme"

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
html_theme_options = {
    "description": "Async HTTP client/server for asyncio and Python",
    "canonical_url": "http://docs.aiohttp.org/en/stable/",
    "github_user": github_repo_org,
    "github_repo": github_repo_name,
    "github_button": True,
    "github_type": "star",
    "github_banner": True,
    "badges": [
        {
            # This badge and its target point at the GitHub Actions "CI"
            # workflow, so the alt text names GitHub Actions.
            "image": f"{github_repo_url}/workflows/CI/badge.svg",
            "target": f"{github_repo_url}/actions?query=workflow%3ACI",
            "height": "20",
            "alt": "GitHub Actions CI status",
        },
        {
            "image": f"https://codecov.io/github/{github_repo_slug}/coverage.svg?branch=master",
            "target": f"https://codecov.io/github/{github_repo_slug}",
            "height": "20",
            "alt": "Code coverage status",
        },
        {
            "image": f"https://badge.fury.io/py/{project}.svg",
            "target": f"https://badge.fury.io/py/{project}",
            "height": "20",
            "alt": "Latest PyPI package version",
        },
        {
            "image": "https://badges.gitter.im/Join%20Chat.svg",
            "target": f"https://gitter.im/{github_repo_org}/Lobby",
            "height": "20",
            "alt": "Chat on Gitter",
        },
    ],
}

html_css_files = [
    "css/logo-adjustments.css",
]

# Add any paths that contain custom themes here, relative to this directory.
# html_theme_path = [alabaster.get_path()]

# The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation".
# html_title = None

# A shorter title for the navigation bar. Default is the same as html_title.
# html_short_title = None

# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
html_logo = "aiohttp-plain.svg"

# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
html_favicon = "favicon.ico"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
# directly to the root of the documentation.
# html_extra_path = []

# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
# html_last_updated_fmt = '%b %d, %Y'

# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
# html_use_smartypants = True

# Custom sidebar templates, maps document names to template names.
html_sidebars = {
    "**": [
        "about.html",
        "navigation.html",
        "searchbox.html",
    ]
}

# Additional templates that should be rendered to pages, maps page names to
# template names.
# html_additional_pages = {}

# If false, no module index is generated.
# html_domain_indices = True

# If false, no index is generated.
# html_use_index = True

# If true, the index is split into individual pages for each letter.
# html_split_index = False

# If true, links to the reST sources are added to the pages.
# html_show_sourcelink = True

# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
# html_show_sphinx = True

# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
# html_show_copyright = True

# If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
# html_use_opensearch = ''

# This is the file name suffix for HTML files (e.g. ".xhtml").
# html_file_suffix = None

# Output file base name for HTML help builder.
htmlhelp_basename = f"{project}doc"

# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    # 'papersize': 'letterpaper',
    # The font size ('10pt', '11pt' or '12pt').
    # 'pointsize': '10pt',
    # Additional stuff for the LaTeX preamble.
    # 'preamble': '',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (
        "index",
        f"{project}.tex",
        f"{project} Documentation",
        f"{project} contributors",
        "manual",
    ),
]

# The name of an image file (relative to this directory) to place at the top of
# the title page.
# latex_logo = None

# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
# latex_use_parts = False

# If true, show page references after internal links.
# latex_show_pagerefs = False

# If true, show URL addresses after external links.
# latex_show_urls = False

# Documents to append as an appendix to all manuals.
# latex_appendices = []

# If false, no module index is generated.
# latex_domain_indices = True

# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [("index", project, f"{project} Documentation", [project], 1)]

# If true, show URL addresses after external links.
# man_show_urls = False

# -- Options for Texinfo output -------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (
        "index",
        project,
        f"{project} Documentation",
        # Same author string as latex_documents.
        f"{project} contributors",
        project,
        # Reuse the project description from the HTML theme options rather
        # than the sphinx-quickstart placeholder text.
        html_theme_options["description"],
        "Miscellaneous",
    ),
]

# Documents to append as an appendix to all manuals.
# texinfo_appendices = []

# If false, no module index is generated.
# texinfo_domain_indices = True

# How to display URL addresses: 'footnote', 'no', or 'inline'.
# texinfo_show_urls = 'footnote'

# If true, do not generate a @detailmenu in the "Top" node's menu.
# texinfo_no_detailmenu = False # ------------------------------------------------------------------------- nitpicky = True nitpick_ignore = [ ("py:mod", "aiohttp"), # undocumented, no `.. currentmodule:: aiohttp` in docs ("py:class", "aiohttp.SimpleCookie"), # undocumented ("py:class", "aiohttp.web.RequestHandler"), # undocumented ("py:class", "aiohttp.NamedPipeConnector"), # undocumented ("py:class", "aiohttp.protocol.HttpVersion"), # undocumented ("py:class", "HttpVersion"), # undocumented ("py:class", "aiohttp.payload.Payload"), # undocumented ("py:class", "Payload"), # undocumented ("py:class", "aiohttp.resolver.AsyncResolver"), # undocumented ("py:class", "aiohttp.resolver.ThreadedResolver"), # undocumented ("py:func", "aiohttp.ws_connect"), # undocumented ("py:meth", "start"), # undocumented ("py:exc", "aiohttp.ClientHttpProxyError"), # undocumented ("py:class", "asyncio.AbstractServer"), # undocumented ("py:mod", "aiohttp.test_tools"), # undocumented ("py:class", "list of pairs"), # undocumented ("py:class", "aiohttp.protocol.HttpVersion"), # undocumented ("py:meth", "aiohttp.ClientSession.request"), # undocumented ("py:class", "aiohttp.StreamWriter"), # undocumented ("py:attr", "aiohttp.StreamResponse.body"), # undocumented ("py:class", "aiohttp.payload.StringPayload"), # undocumented ("py:meth", "aiohttp.web.Application.copy"), # undocumented ("py:meth", "asyncio.AbstractEventLoop.create_server"), # undocumented ("py:data", "aiohttp.log.server_logger"), # undocumented ("py:data", "aiohttp.log.access_logger"), # undocumented ("py:data", "aiohttp.helpers.AccessLogger"), # undocumented ("py:attr", "helpers.AccessLogger.LOG_FORMAT"), # undocumented ("py:meth", "aiohttp.web.AbstractRoute.url"), # undocumented ("py:class", "aiohttp.web.MatchedSubAppResource"), # undocumented ("py:attr", "body"), # undocumented ("py:class", "socket.socket"), # undocumented ("py:func", "socket.socket"), # undocumented ("py:class", "socket.AddressFamily"), # undocumented 
("py:obj", "logging.DEBUG"), # undocumented ("py:class", "aiohttp.abc.AbstractAsyncAccessLogger"), # undocumented ("py:meth", "aiohttp.web.Response.write_eof"), # undocumented ("py:meth", "aiohttp.payload.Payload.set_content_disposition"), # undocumented ("py:class", "cgi.FieldStorage"), # undocumented ("py:meth", "aiohttp.web.UrlDispatcher.register_resource"), # undocumented ("py:func", "aiohttp_debugtoolbar.setup"), # undocumented ("py:class", "socket.SocketKind"), # undocumented ] # -- Options for towncrier_draft extension ----------------------------------- towncrier_draft_autoversion_mode = "draft" # or: 'sphinx-version', 'sphinx-release' towncrier_draft_include_empty = True towncrier_draft_working_directory = PROJECT_ROOT_DIR # Not yet supported: towncrier_draft_config_path = 'pyproject.toml' # relative to cwd ================================================ FILE: docs/contributing-admins.rst ================================================ :orphan: Instructions for aiohttp admins =============================== This page is intended to document certain processes for admins of the aiohttp repository. For regular contributors, return to :doc:`contributing`. .. contents:: :local: Creating a new release ---------------------- .. note:: The example commands assume that ``origin`` refers to the ``aio-libs`` repository. To create a new release: #. Start on the branch for the release you are planning (e.g. ``3.8`` for v3.8.6): ``git checkout 3.8 && git pull`` #. Update the version number in ``__init__.py``. #. Run ``towncrier``. #. Check and cleanup the changes in ``CHANGES.rst``. #. Checkout a new branch: e.g. ``git checkout -b release/v3.8.6`` #. Commit and create a PR. Verify the changelog and release notes look good on Read the Docs. Once PR is merged, continue. #. Go back to the release branch: e.g. ``git checkout 3.8 && git pull`` #. Add a tag: e.g. ``git tag -a v3.8.6 -m 'Release 3.8.6' -s`` #. Push the tag: e.g. ``git push origin v3.8.6`` #. 
Monitor CI to ensure release process completes without errors. Once released, we need to complete some cleanup steps (no further steps are needed for non-stable releases though). If doing a patch release, we need to do the below steps twice, first merge into the newer release branch (e.g. 3.8 into 3.9) and then to master (e.g. 3.9 into master). If a new minor release, then just merge to master. #. Switch to target branch: e.g. ``git checkout 3.9 && git pull`` #. Start a merge: e.g. ``git merge 3.8 --no-commit --no-ff --gpg-sign`` #. Carefully review the changes and revert anything that should not be included (most things outside the changelog). #. To ensure change fragments are cleaned up properly, run: ``python tools/cleanup_changes.py`` #. Commit the merge (must be a normal merge commit, not squashed). #. Push the branch directly to Github (because a PR would get squashed). When pushing, you may get a rejected message. Follow these steps to resolve: #. Checkout to a new branch and push: e.g. ``git checkout -b do-not-merge && git push`` #. Open a *draft* PR with a title of 'DO NOT MERGE'. #. Once the CI has completed on that branch, you should be able to switch back and push the target branch (as tests have passed on the merge commit now). #. This should automatically consider the PR merged and delete the temporary branch. Back on the original release branch, bump the version number and append ``.dev0`` in ``__init__.py``. Post the release announcement to social media: - BlueSky: https://bsky.app/profile/aiohttp.org and re-post to https://bsky.app/profile/aio-libs.org - Mastodon: https://fosstodon.org/@aiohttp and re-post to https://fosstodon.org/@aio_libs If doing a minor release: #. Create a new release branch for future features to go to: e.g. ``git checkout -b 3.10 3.9 && git push`` #. Update both ``target-branch`` backports for Dependabot to reference the new branch name in ``.github/dependabot.yml``. #. Delete the older backport label (e.g. 
backport-3.8): https://github.com/aio-libs/aiohttp/labels #. Add a new backport label (e.g. backport-3.10). ================================================ FILE: docs/contributing.rst ================================================ .. _aiohttp-contributing: Contributing ============ (:doc:`contributing-admins`) Instructions for contributors ----------------------------- In order to make a clone of the GitHub_ repo: open the link and press the "Fork" button on the upper-right menu of the web page. I hope everybody knows how to work with git and github nowadays :) Workflow is pretty straightforward: 0. Make sure you are reading the latest version of this document. It can be found in the GitHub_ repo in the ``docs`` subdirectory. 1. Clone the GitHub_ repo using the ``--recurse-submodules`` argument 2. Setup your machine with the required development environment 3. Make a change 4. Make sure all tests passed 5. Add a file into the ``CHANGES`` folder (see `Changelog update`_ for how). 6. Commit changes to your own aiohttp clone 7. Make a pull request from the github page of your clone against the master branch 8. Optionally make backport Pull Request(s) for landing a bug fix into released aiohttp versions. .. note:: The project uses *Squash-and-Merge* strategy for *GitHub Merge* button. Basically it means that there is **no need to rebase** a Pull Request against *master* branch. Just ``git merge`` *master* into your working copy (a fork) if needed. The Pull Request is automatically squashed into the single commit once the PR is accepted. .. note:: GitHub issue and pull request threads are automatically locked when there has not been any recent activity for one year. Please open a `new issue `_ for related bugs. If you feel like there are important points in the locked discussions, please include those excerpts into that new issue. 
Preconditions for running aiohttp test suite -------------------------------------------- We expect you to use a python virtual environment to run our tests. There are several ways to make a virtual environment. If you like to use *virtualenv* please run: .. code-block:: shell $ cd aiohttp $ virtualenv --python=`which python3` venv $ . venv/bin/activate For standard python *venv*: .. code-block:: shell $ cd aiohttp $ python3 -m venv venv $ . venv/bin/activate For *virtualenvwrapper*: .. code-block:: shell $ cd aiohttp $ mkvirtualenv --python=`which python3` aiohttp There are other tools like *pyvenv* but you know the rule of thumb now: create a python3 virtual environment and activate it. After that please install libraries required for development: .. code-block:: shell $ make install-dev .. note:: For now, the development tooling depends on ``make`` and assumes an Unix OS If you wish to contribute to aiohttp from a Windows machine, the easiest way is probably to `configure the WSL `_ so you can use the same instructions. If it's not possible for you or if it doesn't work, please contact us so we can find a solution together. Install pre-commit hooks: .. code-block:: shell $ pre-commit install .. warning:: If you plan to use temporary ``print()``, ``pdb`` or ``ipdb`` within the test suite, execute it with ``-s``: .. code-block:: shell $ pytest tests -s in order to run the tests without output capturing. Congratulations, you are ready to run the test suite! .. include:: ../vendor/README.rst Run autoformatter ----------------- The project uses black_ + isort_ formatters to keep the source code style. Please run `make fmt` after every change before starting tests. .. code-block:: shell $ make fmt Run aiohttp test suite ---------------------- After all the preconditions are met you can run tests typing the next command: .. 
code-block:: shell $ make test The command at first will run the *linters* (sorry, we don't accept pull requests with pyflakes, black, isort, or mypy errors). On *lint* success the tests will be run. Please take a look on the produced output. Any extra texts (print statements and so on) should be removed. Code coverage ------------- We use *codecov.io* as an indispensable tool for analyzing our coverage results. Visit https://codecov.io/gh/aio-libs/aiohttp to see coverage reports for the master branch, history, pull requests etc. We'll use an example from a real PR to demonstrate how we use this. Once the tests run in a PR, you'll see a comment posted by *codecov*. The most important thing to check here is whether there are any new missed or partial lines in the report: .. image:: _static/img/contributing-cov-comment.svg Here, the PR has introduced 1 miss and 2 partials. Now we click the link in the comment header to open the full report: .. image:: _static/img/contributing-cov-header.svg :alt: Codecov report Now, if we look through the diff under 'Files changed' we find one of our partials: .. image:: _static/img/contributing-cov-partial.svg :alt: A while loop with partial coverage. In this case, the while loop is never skipped in our tests. This is probably not worth writing a test for (and may be a situation that is impossible to trigger anyway), so we leave this alone. We're still missing a partial and a miss, so we switch to the 'Indirect changes' tab and take a look through the diff there. This time we find the remaining 2 lines: .. image:: _static/img/contributing-cov-miss.svg :alt: An if statement that isn't covered anymore. After reviewing the PR, we find that this code is no longer needed as the changes mean that this method will never be called under those conditions. 
Thanks to this report, we were able to remove some redundant code from a performance-critical part of our codebase (this check would have been run, probably multiple times, for every single incoming request). .. tip:: Sometimes the diff on *codecov.io* doesn't make sense. This is usually caused by the branch being out of sync with master. Try merging master into the branch and it will likely fix the issue. Failing that, try checking coverage locally as described in the next section. Other tools ----------- The browser extension https://docs.codecov.io/docs/browser-extension is also a useful tool for analyzing the coverage directly from the *Files Changed* tab on the *GitHub Pull Request* review page. You can also produce coverage reports locally with ``make cov-dev`` or by adding ``--cov-report=html`` to ``pytest``. This will run the test suite and collect coverage information. Once finished, coverage results can be viewed by opening: .. code-block:: shell $ python -m webbrowser -n file://"$(pwd)"/htmlcov/index.html Documentation ------------- We encourage documentation improvements. Before making a Pull Request with documentation changes, please run: .. code-block:: shell $ make doc Once it finishes it will output the index html page ``open file:///.../aiohttp/docs/_build/html/index.html``. Go to the link and make sure your doc changes look good. Spell checking -------------- We use ``pyenchant`` and ``sphinxcontrib-spelling`` to spell-check the documentation: .. code-block:: shell $ make doc-spelling Unfortunately there are problems with running the spell checker on macOS. To run the spell checker on a Linux box you should install it first: .. code-block:: shell $ sudo apt-get install enchant $ pip install sphinxcontrib-spelling Preparing a pull request ------------------------ When making a pull request, please include a short summary of the changes and a reference to any issue tickets that the PR is intended to solve. All PRs with code changes should include tests.
All changes should include a changelog entry. Changelog update ---------------- .. include:: ../CHANGES/README.rst Making a pull request --------------------- After finishing all steps make a GitHub_ Pull Request with *master* base branch. Backporting ----------- All Pull Requests are created against the *master* git branch. If the Pull Request is a bug fix rather than new functionality, a *backport* to a maintenance branch is desirable. An *aiohttp* project committer may ask for a *backport* of the PR into maintained branch(es); in this case they add a GitHub label like *needs backport to 3.1*. *Backporting* is performed *after* the main PR is merged into master. Please do the following steps: 1. Find *Pull Request's commit* for cherry-picking. *aiohttp* does *squashing* PRs on merging, so open your PR page on GitHub and scroll down to a message like ``asvetlov merged commit f7b8921 into master 9 days ago``. ``f7b8921`` is the required commit number. 2. Run `cherry_picker `_ tool for making backport PR (the tool is already pre-installed from ``./requirements/dev.txt``), e.g. ``cherry_picker f7b8921 3.1``. 3. In case of conflicts fix them and continue cherry-picking by ``cherry_picker --continue``. ``cherry_picker --abort`` stops the process. ``cherry_picker --status`` shows current cherry-picking status (like ``git status``) 4. After all conflicts are resolved the tool opens a New Pull Request page in a browser with pre-filled information. Create a backport Pull Request and wait for review/merging. 5. *aiohttp* *committer* should remove *backport Git label* after merging the backport. How to become an aiohttp committer ---------------------------------- Contribute! The easiest way is providing Pull Requests for issues in our bug tracker. But if you have a great idea for the library improvement -- please make an issue and Pull Request. The rules for committers are simple: 1. No wild commits! Everything should go through PRs. 2. Take part in reviews.
It's a very important part of a maintainer's activity. 3. Pick up issues created by others, especially if they are simple. 4. Keep the test suite comprehensive. In practice it means leveling up coverage. 97% is not bad but we wish to have 100% someday. Well, 99% is a good target too. 5. Don't hesitate to improve our docs. Documentation is a very important thing, it's the key for project success. The documentation should not only cover our public API but help newbies to start using the project and shed light on non-obvious gotchas. After a positive answer, an aiohttp committer creates an issue on GitHub with the proposal for nomination. If the proposal collects only positive votes and no strong objections -- you'll be a new member in our team. .. _GitHub: https://github.com/aio-libs/aiohttp .. _ipdb: https://pypi.python.org/pypi/ipdb .. _black: https://pypi.python.org/pypi/black .. _isort: https://pypi.python.org/pypi/isort ================================================ FILE: docs/deployment.rst ================================================ .. _aiohttp-deployment: ================= Server Deployment ================= There are several options for aiohttp server deployment: * Standalone server * Running a pool of backend servers behind :term:`nginx`, HAProxy or another *reverse proxy server* * Using :term:`gunicorn` behind a *reverse proxy* Every method has its own benefits and disadvantages. .. _aiohttp-deployment-standalone: Standalone ========== Just call :func:`aiohttp.web.run_app` function passing :class:`aiohttp.web.Application` instance. The method is very simple and could be the best solution in some trivial cases. But it does not utilize all CPU cores. For running multiple aiohttp server instances use *reverse proxies*. .. _aiohttp-deployment-nginx-supervisord: Nginx+supervisord ================= Running aiohttp servers behind :term:`nginx` has several advantages. First, nginx is the perfect frontend server.
It may prevent many attacks based on malformed http protocol etc. Second, running several aiohttp instances behind nginx allows to utilize all CPU cores. Third, nginx serves static files much faster than built-in aiohttp static file support. But this way requires more complex configuration. Nginx configuration -------------------- Here is short example of an Nginx configuration file. It does not cover all available Nginx options. For full details, read `Nginx tutorial `_ and `official Nginx documentation `_. First configure HTTP server itself: .. code-block:: nginx http { server { listen 80; client_max_body_size 4G; server_name example.com; location / { proxy_set_header Host $http_host; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_redirect off; proxy_buffering off; proxy_pass http://aiohttp; } location /static { # path for static files root /path/to/app/static; } } } This config listens on port ``80`` for a server named ``example.com`` and redirects everything to the ``aiohttp`` backend group. Also it serves static files from ``/path/to/app/static`` path as ``example.com/static``. Next we need to configure *aiohttp upstream group*: .. 
code-block:: nginx http { upstream aiohttp { # fail_timeout=0 means we always retry an upstream even if it failed # to return a good HTTP response # Unix domain servers server unix:/tmp/example_1.sock fail_timeout=0; server unix:/tmp/example_2.sock fail_timeout=0; server unix:/tmp/example_3.sock fail_timeout=0; server unix:/tmp/example_4.sock fail_timeout=0; # Unix domain sockets are used in this example due to their high performance, # but TCP/IP sockets could be used instead: # server 127.0.0.1:8081 fail_timeout=0; # server 127.0.0.1:8082 fail_timeout=0; # server 127.0.0.1:8083 fail_timeout=0; # server 127.0.0.1:8084 fail_timeout=0; } } All HTTP requests for ``http://example.com`` except ones for ``http://example.com/static`` will be redirected to ``example_1.sock``, ``example_2.sock``, ``example_3.sock`` or ``example_4.sock`` backend servers. By default, Nginx uses a round-robin algorithm for backend selection. .. note:: Nginx is not the only existing *reverse proxy server*, but it's the most popular one. Alternatives like HAProxy may be used as well. Supervisord ----------- After configuring Nginx we need to start our aiohttp backends. It's best to use some tool for starting them automatically after a system reboot or backend crash. There are many ways to do it: Supervisord, Upstart, Systemd, Gaffer, Circus, Runit etc. Here we'll use `Supervisord `_ as an example: .. code-block:: cfg [program:aiohttp] numprocs = 4 numprocs_start = 1 process_name = example_%(process_num)s ; Unix socket paths are specified by command line. command=/path/to/aiohttp_example.py --path=/tmp/example_%(process_num)s.sock ; We can just as easily pass TCP port numbers: ; command=/path/to/aiohttp_example.py --port=808%(process_num)s user=nobody autostart=true autorestart=true aiohttp server -------------- The last step is preparing the aiohttp server to work with supervisord.
Assuming we have properly configured :class:`aiohttp.web.Application` and port is specified by command line, the task is trivial: .. code-block:: python3 # aiohttp_example.py import argparse from aiohttp import web parser = argparse.ArgumentParser(description="aiohttp server example") parser.add_argument('--path') parser.add_argument('--port') if __name__ == '__main__': app = web.Application() # configure app args = parser.parse_args() web.run_app(app, path=args.path, port=args.port) For real use cases we perhaps need to configure other things like logging etc., but it's out of scope of the topic. .. _aiohttp-deployment-gunicorn: Nginx+Gunicorn ============== aiohttp can be deployed using `Gunicorn `_, which is based on a pre-fork worker model. Gunicorn launches your app as worker processes for handling incoming requests. As opposed to deployment with :ref:`bare Nginx `, this solution does not need to manually run several aiohttp processes and use a tool like supervisord to monitor them. But nothing is free: running aiohttp application under gunicorn is slightly slower. Prepare environment ------------------- You first need to setup your deployment environment. This example is based on `Ubuntu `_ 16.04. Create a directory for your application:: >> mkdir myapp >> cd myapp Create a Python virtual environment:: >> python3 -m venv venv >> source venv/bin/activate Now that the virtual environment is ready, we'll proceed to install aiohttp and gunicorn:: >> pip install gunicorn >> pip install aiohttp Application ----------- Lets write a simple application, which we will save to file. 
We'll name this file *my_app_module.py*:: from aiohttp import web async def index(request): return web.Response(text="Welcome home!") my_web_app = web.Application() my_web_app.router.add_get('/', index) Application factory ------------------- As an option, the entry point could be a coroutine that accepts no parameters and returns an application instance:: from aiohttp import web async def index(request): return web.Response(text="Welcome home!") async def my_web_app(): app = web.Application() app.router.add_get('/', index) return app Start Gunicorn -------------- When `Running Gunicorn `_, you provide the name of the module, i.e. *my_app_module*, and the name of the app or application factory, i.e. *my_web_app*, along with other `Gunicorn Settings `_ provided as command line flags or in your config file. In this case, we will use: * the ``--bind`` flag to set the server's socket address; * the ``--worker-class`` flag to tell Gunicorn that we want to use a custom worker subclass instead of one of the Gunicorn default worker types; * you may also want to use the ``--workers`` flag to tell Gunicorn how many worker processes to use for handling requests. (See the documentation for recommendations on `How Many Workers? `_) * you may also want to use the ``--accesslog`` flag to enable the access log to be populated. (See :ref:`logging ` for more information.) The custom worker subclass is defined in ``aiohttp.GunicornWebWorker``:: >> gunicorn my_app_module:my_web_app --bind localhost:8080 --worker-class aiohttp.GunicornWebWorker [2017-03-11 18:27:21 +0000] [1249] [INFO] Starting gunicorn 19.7.1 [2017-03-11 18:27:21 +0000] [1249] [INFO] Listening at: http://127.0.0.1:8080 (1249) [2017-03-11 18:27:21 +0000] [1249] [INFO] Using worker: aiohttp.worker.GunicornWebWorker [2017-03-11 18:27:21 +0000] [1253] [INFO] Booting worker with pid: 1253 Gunicorn is now running and ready to serve requests to your app's worker processes. ..
note:: If you want to use an alternative asyncio event loop `uvloop `_, you can use the ``aiohttp.GunicornUVLoopWebWorker`` worker class. Proxy through NGINX ---------------------- We can proxy our gunicorn workers through NGINX with a configuration like this: .. code-block:: nginx worker_processes 1; user nobody nogroup; events { worker_connections 1024; } http { ## Main Server Block server { ## Open by default. listen 80 default_server; server_name main; client_max_body_size 200M; ## Main site location. location / { proxy_pass http://127.0.0.1:8080; proxy_set_header Host $host; proxy_set_header X-Forwarded-Host $server_name; proxy_set_header X-Real-IP $remote_addr; } } } Since gunicorn listens for requests at our localhost address on port 8080, we can use the `proxy_pass `_ directive to send web traffic to our workers. If everything is configured correctly, we should reach our application at the ip address of our web server. Proxy through NGINX + SSL ---------------------------- Here is an example NGINX configuration setup to accept SSL connections: .. code-block:: nginx worker_processes 1; user nobody nogroup; events { worker_connections 1024; } http { ## SSL Redirect server { listen 80 default; return 301 https://$host$request_uri; } ## Main Server Block server { # Open by default. listen 443 ssl default_server; listen [::]:443 ssl default_server; server_name main; client_max_body_size 200M; ssl_certificate /etc/secrets/cert.pem; ssl_certificate_key /etc/secrets/key.pem; ## Main site location. location / { proxy_pass http://127.0.0.1:8080; proxy_set_header Host $host; proxy_set_header X-Forwarded-Host $server_name; proxy_set_header X-Real-IP $remote_addr; } } } The first server block accepts regular http connections on port 80 and redirects them to our secure SSL connection. 
The second block matches our previous example except we need to change our open port to https and specify where our SSL certificates are being stored with the ``ssl_certificate`` and ``ssl_certificate_key`` directives. During development, you may want to `create your own self-signed certificates for testing purposes `_ and use another service like `Let's Encrypt `_ when you are ready to move to production. More information ---------------- See the `official documentation `_ for more information about suggested nginx configuration. You can also find out more about `configuring for secure https connections as well. `_ Logging configuration --------------------- ``aiohttp`` and ``gunicorn`` use different format for specifying access log. By default aiohttp uses own defaults:: '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"' For more information please read :ref:`Format Specification for Access Log `. Proxy through Apache at your own risk ------------------------------------- Issues have been reported using Apache2 in front of aiohttp server: `#2687 Intermittent 502 proxy errors when running behind Apache `. ================================================ FILE: docs/essays.rst ================================================ Essays ====== .. toctree:: new_router whats_new_1_1 migration_to_2xx whats_new_3_0 ================================================ FILE: docs/external.rst ================================================ Who uses aiohttp? ================= The list of *aiohttp* users: both libraries, big projects and web sites. Please don't hesitate to add your awesome project to the list by making a Pull Request on GitHub_. If you like the project -- please go to GitHub_ and press *Star* button! .. toctree:: third_party built_with powered_by .. _GitHub: https://github.com/aio-libs/aiohttp ================================================ FILE: docs/faq.rst ================================================ FAQ === .. 
contents:: :local: Are there plans for an @app.route decorator like in Flask? ---------------------------------------------------------- As of aiohttp 2.3, :class:`~aiohttp.web.RouteTableDef` provides an API similar to Flask's ``@app.route``. See :ref:`aiohttp-web-alternative-routes-definition`. Unlike Flask's ``@app.route``, :class:`~aiohttp.web.RouteTableDef` does not require an ``app`` in the module namespace (which often leads to circular imports). Instead, a :class:`~aiohttp.web.RouteTableDef` is decoupled from an application instance:: routes = web.RouteTableDef() @routes.get('/get') async def handle_get(request): ... @routes.post('/post') async def handle_post(request): ... app.router.add_routes(routes) Does aiohttp have a concept like Flask's "blueprint" or Django's "app"? ----------------------------------------------------------------------- If you're writing a large application, you may want to consider using :ref:`nested applications `, which are similar to Flask's "blueprints" or Django's "apps". See: :ref:`aiohttp-web-nested-applications`. How do I create a route that matches urls with a given prefix? -------------------------------------------------------------- You can do something like the following: :: app.router.add_route('*', '/path/to/{tail:.+}', sink_handler) The first argument, ``*``, matches any HTTP method (*GET, POST, OPTIONS*, etc). The second argument matches URLS with the desired prefix. The third argument is the handler function. Where do I put my database connection so handlers can access it? ---------------------------------------------------------------- :class:`aiohttp.web.Application` object supports the :class:`dict` interface and provides a place to store your database connections or any other resource you want to share between handlers. :: db_key = web.AppKey("db_key", DB) async def go(request): db = request.app[db_key] cursor = await db.cursor() await cursor.execute('SELECT 42') # ... 
return web.Response(status=200, text='ok') async def init_app(): app = web.Application() db = await create_connection(user='user', password='123') app[db_key] = db app.router.add_get('/', go) return app How can middleware store data for web handlers to use? ------------------------------------------------------ Both :class:`aiohttp.web.Request` and :class:`aiohttp.web.Application` support the :class:`dict` interface. Therefore, data may be stored inside a request object. :: request_id_key = web.RequestKey("request_id_key", str) @web.middleware async def request_id_middleware(request, handler): request[request_id_key] = "some_request_id" return await handler(request) async def handler(request): request_id = request[request_id_key] See https://github.com/aio-libs/aiohttp_session code for an example. The ``aiohttp_session.get_session(request)`` method uses ``SESSION_KEY`` for saving request-specific session information. As of aiohttp 3.0, all response objects are dict-like structures as well. .. _aiohttp_faq_parallel_event_sources: Can a handler receive incoming events from different sources in parallel? ------------------------------------------------------------------------- Yes. As an example, we may have two event sources: 1. WebSocket for events from an end user 2. Redis PubSub for events from other parts of the application The most natural way to handle this is to create a separate task for PubSub handling. Parallel :meth:`aiohttp.web.WebSocketResponse.receive` calls are forbidden; a single task should perform WebSocket reading. However, other tasks may use the same WebSocket object for sending data to peers. :: async def handler(request): ws = web.WebSocketResponse() await ws.prepare(request) task = asyncio.create_task( read_subscription(ws, request.app[redis_key])) try: async for msg in ws: # handle incoming messages # use ws.send_str() to send data back ...
finally: task.cancel() async def read_subscription(ws, redis): channel, = await redis.subscribe('channel:1') try: async for msg in channel.iter(): answer = process_the_message(msg) # your function here await ws.send_str(answer) finally: await redis.unsubscribe('channel:1') .. _aiohttp_faq_terminating_websockets: How do I programmatically close a WebSocket server-side? -------------------------------------------------------- Let's say we have an application with two endpoints: 1. ``/echo``, a WebSocket echo server that authenticates the user 2. ``/logout_user`` that, when invoked, closes all open WebSockets for that user. One simple solution is to keep a shared registry of WebSocket responses for a user in the :class:`aiohttp.web.Application` instance and call :meth:`aiohttp.web.WebSocketResponse.close` on all of them in the ``/logout_user`` handler:: async def echo_handler(request): ws = web.WebSocketResponse() user_id = authenticate_user(request) await ws.prepare(request) request.app[websockets_key][user_id].add(ws) try: async for msg in ws: await ws.send_str(msg.data) finally: request.app[websockets_key][user_id].remove(ws) return ws async def logout_handler(request): user_id = authenticate_user(request) ws_closers = [ws.close() for ws in request.app[websockets_key][user_id] if not ws.closed] # Watch out, this will keep us from returning the response # until all are closed ws_closers and await asyncio.gather(*ws_closers) return web.Response(text='OK') def main(): loop = asyncio.get_event_loop() app = web.Application() app.router.add_route('GET', '/echo', echo_handler) app.router.add_route('POST', '/logout_user', logout_handler) app[websockets_key] = defaultdict(set) web.run_app(app, host='localhost', port=8080) How do I make a request from a specific IP address?
--------------------------------------------------- If your system has several IP interfaces, you may choose one which will be used to bind a socket locally:: conn = aiohttp.TCPConnector(local_addr=('127.0.0.1', 0)) async with aiohttp.ClientSession(connector=conn) as session: ... .. seealso:: :class:`aiohttp.TCPConnector` and its ``local_addr`` parameter. What is the API stability and deprecation policy? ------------------------------------------------- *aiohttp* follows strong `Semantic Versioning `_ (SemVer). Obsolete attributes and methods are marked as *deprecated* in the documentation and raise :class:`DeprecationWarning` upon usage. Assume aiohttp ``X.Y.Z`` where ``X`` is the major version, ``Y`` is the minor version and ``Z`` is the bugfix number. For example, if the latest released version is ``aiohttp==3.0.6``: ``3.0.7`` fixes some bugs but has no new features. ``3.1.0`` introduces new features and can deprecate some APIs but never removes them; all bug fixes from previous releases are also merged. ``4.0.0`` removes all deprecations collected from ``3.Y`` versions **except** deprecations from the **last** ``3.Y`` release. These deprecations will be removed by ``5.0.0``. Unfortunately we may have to break these rules when a **security vulnerability** is found. If a security problem cannot be fixed without breaking backward compatibility, a bugfix release may break compatibility. This is unlikely, but possible. All backward incompatible changes are explicitly marked in :ref:`the changelog `. How do I enable gzip compression globally for my entire application? -------------------------------------------------------------------- It's impossible. Choosing what to compress and what not to compress is a tricky matter. If you need global compression, write a custom middleware. Or enable compression in NGINX (you are deploying aiohttp behind a reverse proxy, right?). How do I manage a ClientSession within a web server?
---------------------------------------------------- :class:`aiohttp.ClientSession` should be created once for the lifetime of the server in order to benefit from connection pooling. Sessions save cookies internally. If you don't need cookie processing, use :class:`aiohttp.DummyCookieJar`. If you need separate cookies for different HTTP calls but process them in logical chains, use a single :class:`aiohttp.TCPConnector` with separate client sessions and ``connector_owner=False``. How do I access database connections from a subapplication? ----------------------------------------------------------- Restricting access from a subapplication to the main (or outer) app is a deliberate choice. A subapplication is an isolated unit by design. If you need to share a database object, do it explicitly:: subapp[db_key] = mainapp[db_key] mainapp.add_subapp("/prefix", subapp) This can also be done from a :ref:`cleanup context`:: @contextlib.asynccontextmanager async def db_context(app: web.Application) -> AsyncIterator[None]: async with create_db() as db: mainapp[db_key] = mainapp[subapp_key][db_key] = db yield mainapp[subapp_key] = subapp mainapp.add_subapp("/prefix", subapp) mainapp.cleanup_ctx.append(db_context) How do I perform operations in a request handler after sending the response? ---------------------------------------------------------------------------- Middlewares can be written to handle post-response operations, but they run after every request. You can explicitly send the response by calling :meth:`aiohttp.web.Response.write_eof`, which starts sending before the handler returns, giving you a chance to execute follow-up operations:: async def ping_handler(request): """Send PONG and increase DB counter.""" # explicitly send the response resp = web.json_response({'message': 'PONG'}) await resp.prepare(request) await resp.write_eof() # increase the pong count request.app[db_key].inc_pong() return resp A :class:`aiohttp.web.Response` object must be returned.
This is required by aiohttp web contracts, even though the response has already been sent. How do I make sure my custom middleware response will behave correctly? ------------------------------------------------------------------------ Sometimes your middleware handlers might need to send a custom response. This is just fine as long as you always create a new :class:`aiohttp.web.Response` object when required. The response object is a Finite State Machine. Once it has been dispatched by the server, it will reach its final state and cannot be used again. The following middleware will make the server hang, once it serves the second response:: from aiohttp import web def misbehaved_middleware(): # don't do this! cached = web.Response(status=200, text='Hi, I am cached!') async def middleware(request, handler): # ignoring response for the sake of this example _res = await handler(request) return cached return middleware The rule of thumb is *one request, one response*. Why is creating a ClientSession outside of an event loop dangerous? ------------------------------------------------------------------- The short answer is: the life-cycle of all asyncio objects should be shorter than the life-cycle of the event loop. The full explanation is longer. All asyncio objects should be correctly finished/disconnected/closed before event loop shutdown. Otherwise the user may see unexpected behavior. In the best case it is a warning about an unclosed resource; in the worst case the program just hangs, an awaited coroutine is never resumed, etc. Consider the following code from ``mod.py``:: import aiohttp session = aiohttp.ClientSession() async def fetch(url): async with session.get(url) as resp: return await resp.text() The session grabs the current event loop instance and stores it in a private variable. The main module imports the module and installs ``uvloop`` (an alternative fast event loop implementation).
``main.py``:: import asyncio import uvloop import mod asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.run(main()) The code is broken: ``session`` is bound to the default ``asyncio`` loop at import time but the loop is changed **after the import** by ``set_event_loop_policy()``. As a result, the ``fetch()`` call hangs. To avoid import dependency hell *aiohttp* encourages creation of ``ClientSession`` from an async function. The same policy works for ``web.Application`` too. Another use case is unit test writing. Very many test libraries (*aiohttp test tools* first) create a new loop instance for every test function execution. It's done for the sake of test isolation. Otherwise pending activity (timers, network packets etc.) from a previous test may interfere with the current one, producing very cryptic and unstable test failures. Note: *class variables* are hidden globals actually. The following code has the same problem as the ``mod.py`` example: the ``session`` variable is the hidden global object:: class A: session = aiohttp.ClientSession() async def fetch(self, url): async with self.session.get(url) as resp: return await resp.text() ================================================ FILE: docs/glossary.rst ================================================ .. _aiohttp-glossary: ========== Glossary ========== .. if you add new entries, keep the alphabetical sorting! .. glossary:: :sorted: aiodns DNS resolver for asyncio. https://pypi.python.org/pypi/aiodns asyncio The library for writing single-threaded concurrent code using coroutines, multiplexing I/O access over sockets and other resources, running network clients and servers, and other related primitives.
Reference implementation of :pep:`3156` https://pypi.python.org/pypi/asyncio/ Brotli Brotli is a generic-purpose lossless compression algorithm that compresses data using a combination of a modern variant of the LZ77 algorithm, Huffman coding and second order context modeling, with a compression ratio comparable to the best currently available general-purpose compression methods. It is similar in speed with deflate but offers more dense compression. The specification of the Brotli Compressed Data Format is defined :rfc:`7932` https://pypi.org/project/Brotli/ brotlicffi An alternative implementation of :term:`Brotli` built using the CFFI library. This implementation supports PyPy correctly. https://pypi.org/project/brotlicffi/ callable Any object that can be called. Use :func:`callable` to check that. gunicorn Gunicorn 'Green Unicorn' is a Python WSGI HTTP Server for UNIX. http://gunicorn.org/ IDNA An Internationalized Domain Name in Applications (IDNA) is an industry standard for encoding Internet Domain Names that contain in whole or in part, in a language-specific script or alphabet, such as Arabic, Chinese, Cyrillic, Tamil, Hebrew or the Latin alphabet-based characters with diacritics or ligatures, such as French. These writing systems are encoded by computers in multi-byte Unicode. Internationalized domain names are stored in the Domain Name System as ASCII strings using Punycode transcription. keep-alive A technique for communicating between HTTP client and server when connection is not closed after sending response but kept open for sending next request through the same socket. It makes communication faster by getting rid of connection establishment for every request. nginx Nginx [engine x] is an HTTP and reverse proxy server, a mail proxy server, and a generic TCP/UDP proxy server. https://nginx.org/en/ percent-encoding A mechanism for encoding information in a Uniform Resource Locator (URL) if URL parts don't fit in safe characters space. 
requests Currently the most popular synchronous library to make HTTP requests in Python. https://requests.readthedocs.io requoting Applying :term:`percent-encoding` to non-safe symbols and decode percent encoded safe symbols back. According to :rfc:`3986` allowed path symbols are:: allowed = unreserved / pct-encoded / sub-delims / ":" / "@" / "/" pct-encoded = "%" HEXDIG HEXDIG unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "=" resource A concept reflects the HTTP **path**, every resource corresponds to *URI*. May have a unique name. Contains :term:`route`\'s for different HTTP methods. route A part of :term:`resource`, resource's *path* coupled with HTTP method. web-handler An endpoint that returns HTTP response. websocket A protocol providing full-duplex communication channels over a single TCP connection. The WebSocket protocol was standardized by the IETF as :rfc:`6455` yarl A library for operating with URL objects. https://pypi.python.org/pypi/yarl Environment Variables ===================== .. envvar:: AIOHTTP_NO_EXTENSIONS If set to a non-empty value while building from source, aiohttp will be built without speedups written as C extensions. This option is primarily useful for debugging. .. envvar:: AIOHTTP_USE_SYSTEM_DEPS If set to a non-empty value while building from source, aiohttp will be built against the system installation of llhttp rather than the vendored library. This option is primarily meant to be used by downstream redistributors. .. envvar:: NETRC If set, HTTP Basic Auth will be read from the file pointed to by this environment variable, rather than from :file:`~/.netrc`. .. seealso:: ``.netrc`` documentation: https://www.gnu.org/software/inetutils/manual/html_node/The-_002enetrc-file.html ================================================ FILE: docs/http_request_lifecycle.rst ================================================ .. 
_aiohttp-request-lifecycle: The aiohttp Request Lifecycle ============================= Why is aiohttp client API that way? -------------------------------------- The first time you use aiohttp, you'll notice that a simple HTTP request is performed not with one, but with up to three steps: .. code-block:: python async with aiohttp.ClientSession() as session: async with session.get('http://python.org') as response: print(await response.text()) It's especially unexpected when coming from other libraries such as the very popular :term:`requests`, where the "hello world" looks like this: .. code-block:: python response = requests.get('http://python.org') print(response.text) So why is the aiohttp snippet so verbose? Because aiohttp is asynchronous, its API is designed to make the most out of non-blocking network operations. In code like this, requests will block three times, and does it transparently, while aiohttp gives the event loop three opportunities to switch context: - When doing the ``.get()``, both libraries send a GET request to the remote server. For aiohttp, this means asynchronous I/O, which is marked here with an ``async with`` that gives you the guarantee that not only it doesn't block, but that it's cleanly finalized. - When doing ``response.text`` in requests, you just read an attribute. The call to ``.get()`` already preloaded and decoded the entire response payload, in a blocking manner. aiohttp loads only the headers when ``.get()`` is executed, letting you decide to pay the cost of loading the body afterward, in a second asynchronous operation. Hence the ``await response.text()``. - ``async with aiohttp.ClientSession()`` does not perform I/O when entering the block, but at the end of it, it will ensure all remaining resources are closed correctly. Again, this is done asynchronously and must be marked as such. 
The session is also a performance tool, as it manages a pool of connections for you, allowing you to reuse them instead of opening and closing a new one at each request. You can even `manage the pool size by passing a connector object `_. Using a session as a best practice ----------------------------------- The requests library does in fact also provide a session system. Indeed, it lets you do: .. code-block:: python with requests.Session() as session: response = session.get('http://python.org') print(response.text) It's just not the default behavior, nor is it advertised early in the documentation. Because of this, most users take a hit in performance, but can quickly start hacking. And for requests, it's an understandable trade-off, since its goal is to be "HTTP for humans" and simplicity has always been more important than performance in this context. However, if one uses aiohttp, one chooses asynchronous programming, a paradigm that makes the opposite trade-off: more verbosity for better performance. And so the library's default behavior reflects this, encouraging you to use performant best practices from the start. How to use the ClientSession ? ------------------------------- By default the :class:`aiohttp.ClientSession` object will hold a connector with a maximum of 100 connections, putting the rest in a queue. This is quite a big number: it means you must be connected to a hundred different servers (not pages!) concurrently before even having to consider if your task needs resource adjustment. In fact, you can picture the session object as a user starting and closing a browser: it wouldn't make sense to do that every time you want to load a new tab. So you are expected to reuse a session object and make many requests from it. For most scripts and average-sized software, this means you can create a single session, and reuse it for the entire execution of the program. You can even pass the session around as a parameter in functions.
For example, the typical "hello world": .. code-block:: python import aiohttp import asyncio async def main(): async with aiohttp.ClientSession() as session: async with session.get('http://python.org') as response: html = await response.text() print(html) asyncio.run(main()) Can become this: .. code-block:: python import aiohttp import asyncio async def fetch(session, url): async with session.get(url) as response: return await response.text() async def main(): async with aiohttp.ClientSession() as session: html = await fetch(session, 'http://python.org') print(html) asyncio.run(main()) On more complex code bases, you can even create a central registry to hold the session object from anywhere in the code, or a higher level ``Client`` class that holds a reference to it. When to create more than one session object then? It arises when you want more granularity with your resources management: - you want to group connections by a common configuration. e.g: sessions can set cookies, headers, timeout values, etc. that are shared for all connections they hold. - you need several threads and want to avoid sharing a mutable object between them. - you want several connection pools to benefit from different queues and assign priorities. e.g: one session never uses the queue and is for high priority requests, the other one has a small concurrency limit and a very long queue, for non important requests. ================================================ FILE: docs/index.rst ================================================ .. aiohttp documentation master file, created by sphinx-quickstart on Wed Mar 5 12:35:35 2014. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. ================== Welcome to AIOHTTP ================== Asynchronous HTTP Client/Server for :term:`asyncio` and Python. Current version is |release|. .. 
_GitHub: https://github.com/aio-libs/aiohttp Key Features ============ - Supports both :ref:`aiohttp-client` and :ref:`HTTP Server `. - Supports both :ref:`Server WebSockets ` and :ref:`Client WebSockets ` out-of-the-box without the Callback Hell. - Web-server has :ref:`aiohttp-web-middlewares`, :ref:`aiohttp-web-signals` and pluggable routing. - Client supports :ref:`middleware ` for customizing request/response processing. .. _aiohttp-installation: Library Installation ==================== .. code-block:: bash $ pip install aiohttp For speeding up DNS resolving by client API you may install :term:`aiodns` as well. This option is highly recommended: .. code-block:: bash $ pip install aiodns Installing all speedups in one command -------------------------------------- The following will get you ``aiohttp`` along with :term:`aiodns` and ``Brotli`` in one bundle. No need to type separate commands anymore! .. code-block:: bash $ pip install aiohttp[speedups] Getting Started =============== Client example -------------- .. code-block:: python import aiohttp import asyncio async def main(): async with aiohttp.ClientSession() as session: async with session.get('http://python.org') as response: print("Status:", response.status) print("Content-type:", response.headers['content-type']) html = await response.text() print("Body:", html[:15], "...") asyncio.run(main()) This prints: .. code-block:: text Status: 200 Content-type: text/html; charset=utf-8 Body: ... Coming from :term:`requests` ? Read :ref:`why we need so many lines `. Server example: ---------------- .. code-block:: python from aiohttp import web async def handle(request): name = request.match_info.get('name', "Anonymous") text = "Hello, " + name return web.Response(text=text) app = web.Application() app.add_routes([web.get('/', handle), web.get('/{name}', handle)]) if __name__ == '__main__': web.run_app(app) For more information please visit :ref:`aiohttp-client` and :ref:`aiohttp-web` pages. 
Development mode ================ When writing your code, we recommend enabling Python's `development mode `_ (``python -X dev``). In addition to the extra features enabled for asyncio, aiohttp will: - Use a strict parser in the client code (which can help detect malformed responses from a server). - Enable some additional checks (resulting in warnings in certain situations). What's new in aiohttp 3? ======================== Go to :ref:`aiohttp_whats_new_3_0` page for aiohttp 3.0 major release changes. Tutorial ======== :ref:`Polls tutorial ` Source code =========== The project is hosted on GitHub_ Please feel free to file an issue on the `bug tracker `_ if you have found a bug or have some suggestion in order to improve the library. Dependencies ============ - *multidict* - *yarl* - *Optional* :term:`aiodns` for fast DNS resolving. The library is highly recommended. .. code-block:: bash $ pip install aiodns - *Optional* :term:`Brotli` or :term:`brotlicffi` for brotli (:rfc:`7932`) client compression support. .. code-block:: bash $ pip install Brotli Communication channels ====================== *aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions Feel free to post your questions and ideas here. *Matrix*: `#aio-libs:matrix.org `_ We support `Stack Overflow `_. Please add *aiohttp* tag to your question there. Contributing ============ Please read the :ref:`instructions for contributors` before making a Pull Request. Authors and License =================== The ``aiohttp`` package is written mostly by Nikolay Kim and Andrew Svetlov. It's *Apache 2* licensed and freely available. Feel free to improve this package and send a pull request to GitHub_. .. _aiohttp-backward-compatibility-policy: Policy for Backward Incompatible Changes ======================================== *aiohttp* keeps backward compatibility. 
When a new release is published that deprecates a *Public API* (method, class, function argument, etc.), the library will guarantee its usage for at least a year and a half from the date of release. Deprecated APIs are reflected in their documentation, and their use will raise :exc:`DeprecationWarning`. However, if there is a strong reason, we may be forced to break this guarantee. The most likely reason would be a critical bug, such as a security issue, which cannot be solved without a major API change. We are working hard to keep these breaking changes as rare as possible. Table Of Contents ================= .. toctree:: :name: mastertoc :maxdepth: 2 client web utilities faq misc external contributing ================================================ FILE: docs/logging.rst ================================================ .. currentmodule:: aiohttp .. _aiohttp-logging: Logging ======= *aiohttp* uses standard :mod:`logging` for tracking the library activity. We have the following loggers enumerated by names: - ``'aiohttp.access'`` - ``'aiohttp.client'`` - ``'aiohttp.internal'`` - ``'aiohttp.server'`` - ``'aiohttp.web'`` - ``'aiohttp.websocket'`` You may subscribe to these loggers to receive logging messages. This page does not provide instructions for subscribing to loggers; the most convenient method is :func:`logging.config.dictConfig`, which configures all loggers in your application. Logging does not work out of the box. It requires at least minimal ``'logging'`` configuration. Example of minimal working logger setup:: import logging from aiohttp import web app = web.Application() logging.basicConfig(level=logging.DEBUG) web.run_app(app, port=5000) .. versionadded:: 4.0.0 Access logs ----------- Access logs are enabled by default. If the `debug` flag is set, and the default logger ``'aiohttp.access'`` is used, access logs will be output to :obj:`~sys.stderr` if no handlers are attached.
Furthermore, if the default logger has no log level set, the log level will be set to :obj:`logging.DEBUG`. This logging may be controlled by :meth:`aiohttp.web.AppRunner` and :func:`aiohttp.web.run_app`. To override the default logger, pass an instance of :class:`logging.Logger` to override the default logger. .. note:: Use ``web.run_app(app, access_log=None)`` to disable access logs. In addition, *access_log_format* may be used to specify the log format. .. _aiohttp-logging-access-log-format-spec: Format specification ^^^^^^^^^^^^^^^^^^^^ The library provides custom micro-language to specifying info about request and response: +--------------+---------------------------------------------------------+ | Option | Meaning | +==============+=========================================================+ | ``%%`` | The percent sign | +--------------+---------------------------------------------------------+ | ``%a`` | Remote IP-address | | | (IP-address of proxy if using reverse proxy) | +--------------+---------------------------------------------------------+ | ``%t`` | Time when the request was started to process | +--------------+---------------------------------------------------------+ | ``%P`` | The process ID of the child that serviced the request | +--------------+---------------------------------------------------------+ | ``%r`` | First line of request | +--------------+---------------------------------------------------------+ | ``%s`` | Response status code | +--------------+---------------------------------------------------------+ | ``%b`` | Size of response in bytes, including HTTP headers | +--------------+---------------------------------------------------------+ | ``%T`` | The time taken to serve the request, in seconds | +--------------+---------------------------------------------------------+ | ``%Tf`` | The time taken to serve the request, in seconds | | | with fraction in %.06f format | 
+--------------+---------------------------------------------------------+ | ``%D`` | The time taken to serve the request, in microseconds | +--------------+---------------------------------------------------------+ | ``%{FOO}i`` | ``request.headers['FOO']`` | +--------------+---------------------------------------------------------+ | ``%{FOO}o`` | ``response.headers['FOO']`` | +--------------+---------------------------------------------------------+ The default access log format is:: '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"' .. versionadded:: 2.3.0 *access_log_class* introduced. Example of a drop-in replacement for the default access logger:: from aiohttp.abc import AbstractAccessLogger class AccessLogger(AbstractAccessLogger): def log(self, request, response, time): self.logger.info(f'{request.remote} ' f'"{request.method} {request.path} ' f'done in {time}s: {response.status}') @property def enabled(self): """Return True if logger is enabled. Override this property if logging is disabled to avoid the overhead of calculating details to feed the logger. This property may be omitted if logging is always enabled. """ return self.logger.isEnabledFor(logging.INFO) .. versionadded:: 4.0.0 ``AccessLogger.log()`` can now access any exception raised while processing the request with ``sys.exc_info()``. .. versionadded:: 4.0.0 If your logging needs to perform IO you can instead inherit from :class:`aiohttp.abc.AbstractAsyncAccessLogger`:: from aiohttp.abc import AbstractAsyncAccessLogger class AccessLogger(AbstractAsyncAccessLogger): async def log(self, request, response, time): logging_service = request.app['logging_service'] await logging_service.log(f'{request.remote} ' f'"{request.method} {request.path} ' f'done in {time}s: {response.status}') @property def enabled(self) -> bool: """Return True if logger is enabled. Override this property if logging is disabled to avoid the overhead of calculating details to feed the logger. 
""" return self.logger.isEnabledFor(logging.INFO) This also allows access to the results of coroutines on the ``request`` and ``response``, e.g. ``request.text()``. .. _gunicorn-accesslog: Gunicorn access logs ^^^^^^^^^^^^^^^^^^^^ When `Gunicorn `_ is used for :ref:`deployment `, its default access log format will be automatically replaced with the default aiohttp's access log format. If Gunicorn's option access_logformat_ is specified explicitly, it should use aiohttp's format specification. Gunicorn's access log works only if accesslog_ is specified explicitly in your config or as a command line option. This configuration can be either a path or ``'-'``. If the application uses a custom logging setup intercepting the ``'gunicorn.access'`` logger, accesslog_ should be set to ``'-'`` to prevent Gunicorn to create an empty access log file upon every startup. Error logs ---------- :mod:`aiohttp.web` uses a logger named ``'aiohttp.server'`` to store errors given on web requests handling. This log is enabled by default. To use a different logger name, pass *logger* (:class:`logging.Logger` instance) to the :meth:`aiohttp.web.AppRunner` constructor. .. _access_logformat: http://docs.gunicorn.org/en/stable/settings.html#access-log-format .. _accesslog: http://docs.gunicorn.org/en/stable/settings.html#accesslog ================================================ FILE: docs/make.bat ================================================ @ECHO OFF REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set BUILDDIR=_build set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . set I18NSPHINXOPTS=%SPHINXOPTS% . if NOT "%PAPER%" == "" ( set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% ) if "%1" == "" goto help if "%1" == "help" ( :help echo.Please use `make ^` where ^ is one of echo. html to make standalone HTML files echo. 
dirhtml to make HTML files named index.html in directories echo. singlehtml to make a single large HTML file echo. pickle to make pickle files echo. json to make JSON files echo. htmlhelp to make HTML files and a HTML help project echo. qthelp to make HTML files and a qthelp project echo. devhelp to make HTML files and a Devhelp project echo. epub to make an epub echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter echo. text to make text files echo. man to make manual pages echo. texinfo to make Texinfo files echo. gettext to make PO message catalogs echo. changes to make an overview over all changed/added/deprecated items echo. xml to make Docutils-native XML files echo. pseudoxml to make pseudoxml-XML files for display purposes echo. linkcheck to check all external links for integrity echo. doctest to run all doctests embedded in the documentation if enabled goto end ) if "%1" == "clean" ( for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i del /q /s %BUILDDIR%\* goto end ) %SPHINXBUILD% 2> nul if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) if "%1" == "html" ( %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/html. goto end ) if "%1" == "dirhtml" ( %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. goto end ) if "%1" == "singlehtml" ( %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 
goto end ) if "%1" == "pickle" ( %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can process the pickle files. goto end ) if "%1" == "json" ( %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can process the JSON files. goto end ) if "%1" == "htmlhelp" ( %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can run HTML Help Workshop with the ^ .hhp project file in %BUILDDIR%/htmlhelp. goto end ) if "%1" == "qthelp" ( %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can run "qcollectiongenerator" with the ^ .qhcp project file in %BUILDDIR%/qthelp, like this: echo.^> qcollectiongenerator %BUILDDIR%\qthelp\aiohttp.qhcp echo.To view the help file: echo.^> assistant -collectionFile %BUILDDIR%\qthelp\aiohttp.ghc goto end ) if "%1" == "devhelp" ( %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp if errorlevel 1 exit /b 1 echo. echo.Build finished. goto end ) if "%1" == "epub" ( %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub if errorlevel 1 exit /b 1 echo. echo.Build finished. The epub file is in %BUILDDIR%/epub. goto end ) if "%1" == "latex" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex if errorlevel 1 exit /b 1 echo. echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. goto end ) if "%1" == "latexpdf" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex cd %BUILDDIR%/latex make all-pdf cd %BUILDDIR%/.. echo. echo.Build finished; the PDF files are in %BUILDDIR%/latex. goto end ) if "%1" == "latexpdfja" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex cd %BUILDDIR%/latex make all-pdf-ja cd %BUILDDIR%/.. echo. echo.Build finished; the PDF files are in %BUILDDIR%/latex. 
goto end ) if "%1" == "text" ( %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text if errorlevel 1 exit /b 1 echo. echo.Build finished. The text files are in %BUILDDIR%/text. goto end ) if "%1" == "man" ( %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man if errorlevel 1 exit /b 1 echo. echo.Build finished. The manual pages are in %BUILDDIR%/man. goto end ) if "%1" == "texinfo" ( %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo if errorlevel 1 exit /b 1 echo. echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. goto end ) if "%1" == "gettext" ( %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale if errorlevel 1 exit /b 1 echo. echo.Build finished. The message catalogs are in %BUILDDIR%/locale. goto end ) if "%1" == "changes" ( %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes if errorlevel 1 exit /b 1 echo. echo.The overview file is in %BUILDDIR%/changes. goto end ) if "%1" == "linkcheck" ( %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck if errorlevel 1 exit /b 1 echo. echo.Link check complete; look for any errors in the above output ^ or in %BUILDDIR%/linkcheck/output.txt. goto end ) if "%1" == "doctest" ( %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest if errorlevel 1 exit /b 1 echo. echo.Testing of doctests in the sources finished, look at the ^ results in %BUILDDIR%/doctest/output.txt. goto end ) if "%1" == "xml" ( %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml if errorlevel 1 exit /b 1 echo. echo.Build finished. The XML files are in %BUILDDIR%/xml. goto end ) if "%1" == "pseudoxml" ( %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml if errorlevel 1 exit /b 1 echo. echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. goto end ) :end ================================================ FILE: docs/migration_to_2xx.rst ================================================ .. 
_aiohttp-migration: Migration to 2.x ================ Client ------ chunking ^^^^^^^^ aiohttp does not support custom chunking sizes. It is up to the developer to decide how to chunk data streams. If chunking is enabled, aiohttp encodes the provided chunks in the "Transfer-encoding: chunked" format. aiohttp does not enable chunked encoding automatically even if a *transfer-encoding* header is supplied: *chunked* has to be set explicitly. If *chunked* is set, then the *Transfer-encoding* and *content-length* headers are disallowed. compression ^^^^^^^^^^^ Compression has to be enabled explicitly with the *compress* parameter. If compression is enabled, adding a *content-encoding* header is not allowed. Compression also enables the *chunked* transfer-encoding. Compression can not be combined with a *Content-Length* header. Client Connector ^^^^^^^^^^^^^^^^ 1. By default a connector object manages a total number of concurrent connections. This limit was a per host rule in version 1.x. In 2.x, the `limit` parameter defines how many concurrent connections the connector can open and a new `limit_per_host` parameter defines the limit per host. By default there is no per-host limit. 2. BaseConnector.close is now a normal function as opposed to a coroutine in version 1.x 3. BaseConnector.conn_timeout was moved to ClientSession ClientResponse.release ^^^^^^^^^^^^^^^^^^^^^^ Internal implementation was significantly redesigned. It is not required to call `release` on the response object. When the client fully receives the payload, the underlying connection automatically returns to the pool. If the payload is not fully read, the connection is closed. Client exceptions ^^^^^^^^^^^^^^^^^ Exception hierarchy has been significantly modified. aiohttp now defines only exceptions that cover connection handling and server response misbehaviors. For developer specific mistakes, aiohttp uses Python standard exceptions like ValueError or TypeError.
Reading response content may raise a ClientPayloadError exception. This exception indicates errors specific to the payload encoding, such as invalid compressed data, malformed chunked-encoded chunks or not enough data to satisfy the content-length header. All exceptions are moved from `aiohttp.errors` module to top level `aiohttp` module. New hierarchy of exceptions: * `ClientError` - Base class for all client specific exceptions - `ClientResponseError` - exceptions that could happen after we get response from server * `WSServerHandshakeError` - web socket server response error - `ClientHttpProxyError` - proxy response - `ClientConnectionError` - exceptions related to low-level connection problems * `ClientOSError` - subset of connection errors that are initiated by an OSError exception - `ClientConnectorError` - connector related exceptions * `ClientProxyConnectionError` - proxy connection initialization error - `ServerConnectionError` - server connection related errors * `ServerDisconnectedError` - server disconnected * `ServerTimeoutError` - server operation timeout (read timeout, etc.) * `ServerFingerprintMismatch` - server fingerprint mismatch - `ClientPayloadError` - This exception can only be raised while reading the response payload if one of these errors occurs: invalid compression, malformed chunked encoding or not enough data to satisfy the content-length header. Client payload (form-data) ^^^^^^^^^^^^^^^^^^^^^^^^^^ To unify form-data/payload handling a new `Payload` system was introduced. It supports customized handling of existing types and provides an implementation for user-defined types. 1. FormData.__call__ does not take an encoding arg anymore and its return value changes from an iterator or bytes to a Payload instance. aiohttp provides payload adapters for some standard types like `str`, `bytes`, `io.IOBase`, `StreamReader` or `DataQueue`. 2. a generator is not supported as data provider anymore, `streamer` can be used instead.
For example, to upload data from a file:: @aiohttp.streamer def file_sender(writer, file_name=None): with open(file_name, 'rb') as f: chunk = f.read(2**16) while chunk: yield from writer.write(chunk) chunk = f.read(2**16) # Then you can use `file_sender` like this: async with session.post('http://httpbin.org/post', data=file_sender(file_name='huge_file')) as resp: print(await resp.text()) Various ^^^^^^^ 1. the `encoding` parameter is deprecated in `ClientSession.request()`. Payload encoding is controlled at the payload level. It is possible to specify an encoding for each payload instance. 2. the `version` parameter is removed in `ClientSession.request()`; the client version can be specified in the `ClientSession` constructor. 3. `aiohttp.MsgType` dropped, use `aiohttp.WSMsgType` instead. 4. `ClientResponse.url` is an instance of `yarl.URL` class (`url_obj` is deprecated) 5. `ClientResponse.raise_for_status()` raises :exc:`aiohttp.ClientResponseError` exception 6. `ClientResponse.json()` is strict about response's content type. If the content type does not match, it raises :exc:`aiohttp.ClientResponseError` exception. To disable content type check you can pass ``None`` as `content_type` parameter. Server ------ ServerHttpProtocol and low-level details ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Internal implementation was significantly redesigned to provide better performance and support HTTP pipelining. ServerHttpProtocol is dropped, its implementation is merged with RequestHandler, and a lot of low-level APIs are dropped. Application ^^^^^^^^^^^ 1. Constructor parameter `loop` is deprecated. The loop is configured by the application runner, the `run_app` function or any of the gunicorn workers. 2. `Application.router.add_subapp` is dropped, use `Application.add_subapp` instead 3. `Application.finished` is dropped, use `Application.cleanup` instead WebRequest and WebResponse ^^^^^^^^^^^^^^^^^^^^^^^^^^ 1. the `GET` and `POST` attributes no longer exist.
Use the `query` attribute instead of `GET` 2. Custom chunking size is not supported by `WebResponse.chunked` - the developer is responsible for actual chunking. 3. Payloads are supported as body. So it is possible to use client response's content object as body parameter for `WebResponse` 4. `FileSender` api is dropped, it is replaced with more general `FileResponse` class:: async def handle(request): return web.FileResponse('path-to-file.txt') 5. `WebSocketResponse.protocol` is renamed to `WebSocketResponse.ws_protocol`. `WebSocketResponse.protocol` is an instance of the `RequestHandler` class. RequestPayloadError ^^^^^^^^^^^^^^^^^^^ Reading request's payload may raise a `RequestPayloadError` exception. The behavior is similar to `ClientPayloadError`. WSGI ^^^^ *WSGI* support has been dropped, as well as gunicorn wsgi support. We still provide default and uvloop gunicorn workers for `web.Application` ================================================ FILE: docs/misc.rst ================================================ .. _aiohttp-misc: Miscellaneous ============= Helpful pages. .. toctree:: :name: misc essays glossary .. toctree:: :titlesonly: changes Indices and tables ------------------ * :ref:`genindex` * :ref:`modindex` * :ref:`search` ================================================ FILE: docs/multipart.rst ================================================ .. currentmodule:: aiohttp .. _aiohttp-multipart: Working with Multipart ====================== ``aiohttp`` supports a full featured multipart reader and writer. Both are designed with streaming processing in mind to avoid unwanted footprint which may be significant if you're dealing with large payloads, but this also means that most I/O operations can only be executed a single time. Reading Multipart Responses --------------------------- Assume you made a request, as usual, and want to process the response multipart data:: async with aiohttp.request(...)
as resp: pass First, you need to wrap the response with a :meth:`MultipartReader.from_response`. This is needed to keep the implementation of :class:`MultipartReader` separated from the response and the connection routines which makes it more portable:: reader = aiohttp.MultipartReader.from_response(resp) Let's assume with this response you'd received some JSON document and multiple files for it, but you don't need all of them, just a specific one. So first you need to enter into a loop where the multipart body will be processed:: metadata = None filedata = None while True: part = await reader.next() The returned type depends on what the next part is: if it's a simple body part then you'll get :class:`BodyPartReader` instance here, otherwise, it will be another :class:`MultipartReader` instance for the nested multipart. Remember, that multipart format is recursive and supports multiple levels of nested body parts. When there are no more parts left to fetch, ``None`` value will be returned - that's the signal to break the loop:: if part is None: break Both :class:`BodyPartReader` and :class:`MultipartReader` provide access to body part headers: this allows you to filter parts by their attributes:: if part.headers[aiohttp.hdrs.CONTENT_TYPE] == 'application/json': metadata = await part.json() continue Neither :class:`BodyPartReader` nor :class:`MultipartReader` instances read the whole body part data without being explicitly asked to. :class:`BodyPartReader` provides a set of helper methods to fetch popular content types in a friendly way: - :meth:`BodyPartReader.text` for plain text data; - :meth:`BodyPartReader.json` for JSON; - :meth:`BodyPartReader.form` for `application/x-www-form-urlencoded` Each of these methods automatically recognizes if content is compressed by using `gzip` and `deflate` encoding (while it respects `identity` one), or if transfer encoding is base64 or `quoted-printable` - in each case the result will get automatically decoded.
But in case you need access to raw binary data as it is, there are :meth:`BodyPartReader.read` and :meth:`BodyPartReader.read_chunk` coroutine methods as well to read raw binary data as it is all-in-single-shot or by chunks respectively. When you have to deal with multipart files, the :attr:`BodyPartReader.filename` property comes to help. It's a very smart helper which handles the `Content-Disposition` header correctly and extracts the right filename attribute from it:: if part.filename != 'secret.txt': continue If the current body part does not match your expectation and you want to skip it - just continue the loop to start its next iteration. Here is where magic happens. Before fetching the next body part ``await reader.next()`` it ensures that the previous one was read completely. If it was not, all its content is sent to the void in order to fetch the next part. So you don't have to care about cleanup routines while you're within a loop. Once you'd found a part for the file you'd searched for, just read it. Let's handle it as it is without applying any decoding magic:: filedata = await part.read(decode=False) Later you may decide to decode the data. It's still simple and possible to do:: filedata = part.decode(filedata) Once you are done with multipart processing, just break a loop:: break Sending Multipart Requests -------------------------- :class:`MultipartWriter` provides an interface to build multipart payload from the Python data and serialize it into a chunked binary stream. Since multipart format is recursive and supports deep nesting, you can use ``with`` statement to design your multipart data closer to how it will be:: with aiohttp.MultipartWriter('mixed') as mpwriter: ... with aiohttp.MultipartWriter('related') as subwriter: ... mpwriter.append(subwriter) with aiohttp.MultipartWriter('related') as subwriter: ... with aiohttp.MultipartWriter('related') as subsubwriter: ...
subwriter.append(subsubwriter) mpwriter.append(subwriter) with aiohttp.MultipartWriter('related') as subwriter: ... mpwriter.append(subwriter) The :meth:`MultipartWriter.append` is used to join new body parts into a single stream. It accepts various inputs and determines what default headers should be used for them. For text data default `Content-Type` is :mimetype:`text/plain; charset=utf-8`:: mpwriter.append('hello') For binary data :mimetype:`application/octet-stream` is used:: mpwriter.append(b'aiohttp') You can always override these defaults by passing your own headers with the second argument:: mpwriter.append(io.BytesIO(b'GIF89a...'), {'CONTENT-TYPE': 'image/gif'}) For file objects `Content-Type` will be determined by using Python's :mod:`mimetypes` module and additionally `Content-Disposition` header will include the file's basename:: part = root.append(open(__file__, 'rb')) If you want to send a file with a different name, just handle the :class:`~aiohttp.payload.Payload` instance which :meth:`MultipartWriter.append` will always return and set `Content-Disposition` explicitly by using the :meth:`Payload.set_content_disposition() <aiohttp.payload.Payload.set_content_disposition>` helper:: part.set_content_disposition('attachment', filename='secret.txt') Additionally, you may want to set other headers here:: part.headers[aiohttp.hdrs.CONTENT_ID] = 'X-12345' If you'd set `Content-Encoding`, it will be automatically applied to the data on serialization (see below):: part.headers[aiohttp.hdrs.CONTENT_ENCODING] = 'gzip' There are also :meth:`MultipartWriter.append_json` and :meth:`MultipartWriter.append_form` helpers which are useful to work with JSON and form urlencoded data, so you don't have to encode it every time manually:: mpwriter.append_json({'test': 'passed'}) mpwriter.append_form([('key', 'value')]) When it's done, to make a request just pass a root :class:`MultipartWriter` instance as :meth:`aiohttp.ClientSession.request` ``data`` argument:: await session.post('http://example.com', data=mpwriter) Behind
the scenes :meth:`MultipartWriter.write` will yield chunks of every part and if body part has `Content-Encoding` or `Content-Transfer-Encoding` they will be applied on streaming content. Please note, that on :meth:`MultipartWriter.write` all the file objects will be read until the end and there is no way to repeat a request without rewinding their pointers to the start. Example MJPEG Streaming ``multipart/x-mixed-replace``. By default :meth:`MultipartWriter.write` appends closing ``--boundary--`` and breaks your content. Providing `close_boundary = False` prevents this:: my_boundary = 'some-boundary' response = web.StreamResponse( status=200, reason='OK', headers={ 'Content-Type': 'multipart/x-mixed-replace;boundary={}'.format(my_boundary) } ) while True: frame = get_jpeg_frame() with MultipartWriter('image/jpeg', boundary=my_boundary) as mpwriter: mpwriter.append(frame, { 'Content-Type': 'image/jpeg' }) await mpwriter.write(response, close_boundary=False) await response.drain() Hacking Multipart ----------------- The Internet is full of terror and sometimes you may find a server which implements multipart support in strange ways when an obvious solution does not work. For instance, if the server uses :class:`cgi.FieldStorage` then you have to ensure that no body part contains a `Content-Length` header:: for part in mpwriter: part.headers.pop(aiohttp.hdrs.CONTENT_LENGTH, None) On the other hand, some servers may require `Content-Length` to be specified for the whole multipart request. `aiohttp` does not do that since it sends multipart using chunked transfer encoding by default.
To overcome this issue, you have to serialize a :class:`MultipartWriter` on your own in order to calculate its size:: class Writer: def __init__(self): self.buffer = bytearray() async def write(self, data): self.buffer.extend(data) writer = Writer() await mpwriter.write(writer) await aiohttp.post('http://example.com', data=writer.buffer, headers=mpwriter.headers) Sometimes the server response may not be well formed: it may or may not contain nested parts. For instance, we request a resource which returns JSON documents with the files attached to them. If the document has any attachments, they are returned as a nested multipart. If it has none, it responds with plain body parts: .. code-block:: none CONTENT-TYPE: multipart/mixed; boundary=--: --: CONTENT-TYPE: application/json {"_id": "foo"} --: CONTENT-TYPE: multipart/related; boundary=----: ----: CONTENT-TYPE: application/json {"_id": "bar"} ----: CONTENT-TYPE: text/plain CONTENT-DISPOSITION: attachment; filename=bar.txt bar! bar! bar! ----:-- --: CONTENT-TYPE: application/json {"_id": "boo"} --: CONTENT-TYPE: multipart/related; boundary=----: ----: CONTENT-TYPE: application/json {"_id": "baz"} ----: CONTENT-TYPE: text/plain CONTENT-DISPOSITION: attachment; filename=baz.txt baz! baz! baz!
----:-- --:-- Reading this kind of data in a single stream is possible, but is not clean at all:: result = [] while True: part = await reader.next() if part is None: break if isinstance(part, aiohttp.MultipartReader): # Fetching files while True: filepart = await part.next() if filepart is None: break result[-1].append((await filepart.read())) else: # Fetching document result.append([(await part.json())]) Let's hack a reader in such a way that it returns pairs of document and reader of the related files on each iteration:: class PairsMultipartReader(aiohttp.MultipartReader): # keep reference on the original reader multipart_reader_cls = aiohttp.MultipartReader async def next(self): """Emits a tuple of document object (:class:`dict`) and multipart reader of the followed attachments (if any). :rtype: tuple """ reader = await super().next() if self._at_eof: return None, None if isinstance(reader, self.multipart_reader_cls): part = await reader.next() doc = await part.json() else: doc = await reader.json() return doc, reader And this gives us a cleaner solution:: reader = PairsMultipartReader.from_response(resp) result = [] while True: doc, files_reader = await reader.next() if doc is None: break files = [] while True: filepart = await files_reader.next() if filepart is None: break files.append((await filepart.read())) result.append((doc, files)) .. seealso:: :ref:`aiohttp-multipart-reference` ================================================ FILE: docs/multipart_reference.rst ================================================ .. currentmodule:: aiohttp .. _aiohttp-multipart-reference: Multipart reference =================== .. class:: MultipartResponseWrapper(resp, stream) :canonical: aiohttp.multipart.MultipartResponseWrapper Wrapper around the :class:`MultipartReader` to take care of the underlying connection and close it when needed. .. method:: at_eof() Returns ``True`` when all response data has been read. :rtype: bool ..
method:: next() :async: Emits next multipart reader object. .. method:: release() :async: Releases the connection gracefully, reading all the content to the void. .. class:: BodyPartReader(boundary, headers, content) :canonical: aiohttp.multipart.BodyPartReader Multipart reader for single body part. .. method:: read(*, decode=False) :async: Reads body part data. :param bool decode: Decodes data according to the encoding method from the ``Content-Encoding`` header. If the header is missing, data remains untouched :rtype: bytearray .. method:: read_chunk(size=chunk_size) :async: Reads body part content chunk of the specified size. :param int size: chunk size :rtype: bytearray .. method:: readline() :async: Reads body part line by line. :rtype: bytearray .. method:: release() :async: Like :meth:`read`, but reads all the data to the void. :rtype: None .. method:: text(*, encoding=None) :async: Like :meth:`read`, but assumes that body part contains text data. :param str encoding: Custom text encoding. Overrides the one specified in the charset param of ``Content-Type`` header :rtype: str .. method:: json(*, encoding=None) :async: Like :meth:`read`, but assumes that body part contains JSON data. :param str encoding: Custom JSON encoding. Overrides the one specified in the charset param of ``Content-Type`` header .. method:: form(*, encoding=None) :async: Like :meth:`read`, but assumes that body part contains form urlencoded data. :param str encoding: Custom form encoding. Overrides the one specified in the charset param of ``Content-Type`` header .. method:: at_eof() Returns ``True`` if the boundary was reached or ``False`` otherwise. :rtype: bool .. method:: decode(data) Decodes data synchronously according to the specified ``Content-Encoding`` or ``Content-Transfer-Encoding`` headers value. Supports ``gzip``, ``deflate`` and ``identity`` encodings for ``Content-Encoding`` header. Supports ``base64``, ``quoted-printable``, ``binary`` encodings for ``Content-Transfer-Encoding`` header. :param bytearray data: Data to decode.
:raises: :exc:`RuntimeError` - if encoding is unknown. :rtype: bytes .. note:: For large payloads, consider using :meth:`decode_iter` instead to avoid blocking the event loop during decompression. .. method:: decode_iter(data) :async: Decodes data asynchronously according the specified ``Content-Encoding`` or ``Content-Transfer-Encoding`` headers value. This is an async iterator and will return decoded data in chunks. This can be used to avoid loading large payloads into memory. This method offloads decompression to an executor for large payloads to avoid blocking the event loop. Supports ``gzip``, ``deflate`` and ``identity`` encodings for ``Content-Encoding`` header. Supports ``base64``, ``quoted-printable``, ``binary`` encodings for ``Content-Transfer-Encoding`` header. :param bytearray data: Data to decode. :raises: :exc:`RuntimeError` - if encoding is unknown. :rtype: bytes .. versionadded:: 3.13.4 .. method:: get_charset(default=None) Returns charset parameter from ``Content-Type`` header or default. .. attribute:: name A field *name* specified in ``Content-Disposition`` header or ``None`` if missed or header is malformed. Readonly :class:`str` property. .. attribute:: filename A field *filename* specified in ``Content-Disposition`` header or ``None`` if missed or header is malformed. Readonly :class:`str` property. .. class:: MultipartReader(headers, content) :canonical: aiohttp.multipart.MultipartReader Multipart body reader. .. classmethod:: from_response(cls, response) Constructs reader instance from HTTP response. :param response: :class:`~aiohttp.ClientResponse` instance .. method:: at_eof() Returns ``True`` if the final boundary was reached or ``False`` otherwise. :rtype: bool .. method:: next() :async: Emits the next multipart body part. .. method:: release() :async: Reads all the body parts to the void till the final boundary. .. method:: fetch_next_part() :async: Returns the next body part reader. .. 
class:: MultipartWriter(subtype='mixed', boundary=None, close_boundary=True) :canonical: aiohttp.multipart.MultipartWriter Multipart body writer. ``boundary`` may be an ASCII-only string. .. attribute:: boundary The string (:class:`str`) representation of the boundary. .. versionchanged:: 3.0 Property type was changed from :class:`bytes` to :class:`str`. .. method:: append(obj, headers=None) Append an object to writer. .. method:: append_payload(payload) Adds a new body part to multipart writer. .. method:: append_json(obj, headers=None) Helper to append JSON part. .. method:: append_form(obj, headers=None) Helper to append form urlencoded part. .. attribute:: size Size of the payload. .. method:: write(writer, close_boundary=True) :async: Write body. :param bool close_boundary: The (:class:`bool`) that will emit boundary closing. You may want to disable when streaming (``multipart/x-mixed-replace``) .. versionadded:: 3.4 Support ``close_boundary`` argument. ================================================ FILE: docs/new_router.rst ================================================ .. _aiohttp-router-refactoring-021: Router refactoring in 0.21 ========================== Rationale --------- First generation (v1) of router has mapped ``(method, path)`` pair to :term:`web-handler`. Mapping is named **route**. Routes used to have unique names if any. The main mistake with the design is coupling the **route** to ``(method, path)`` pair while really URL construction operates with **resources** (**location** is a synonym). HTTP method is not part of URI but applied on sending HTTP request only. Having different **route names** for the same path is confusing. Moreover **named routes** constructed for the same path should have unique non overlapping names which is cumbersome is certain situations. From other side sometimes it's desirable to bind several HTTP methods to the same web handler. For *v1* router it can be solved by passing '*' as HTTP method. 
Class based views require '*' method also usually. Implementation -------------- The change introduces **resource** as first class citizen:: resource = router.add_resource('/path/{to}', name='name') *Resource* has a **path** (dynamic or constant) and optional **name**. The name is **unique** in router context. *Resource* has **routes**. *Route* corresponds to *HTTP method* and :term:`web-handler` for the method:: route = resource.add_route('GET', handler) User still may use wildcard for accepting all HTTP methods (maybe we will add something like ``resource.add_wildcard(handler)`` later). Since **names** belongs to **resources** now ``app.router['name']`` returns a **resource** instance instead of :class:`aiohttp.web.AbstractRoute`. **resource** has ``.url()`` method, so ``app.router['name'].url(parts={'a': 'b'}, query={'arg': 'param'})`` still works as usual. The change allows to rewrite static file handling and implement nested applications as well. Decoupling of *HTTP location* and *HTTP method* makes life easier. Backward compatibility ---------------------- The refactoring is 99% compatible with previous implementation. 99% means all example and the most of current code works without modifications but we have subtle API backward incompatibles. ``app.router['name']`` returns a :class:`aiohttp.web.AbstractResource` instance instead of :class:`aiohttp.web.AbstractRoute` but resource has the same ``resource.url(...)`` most useful method, so end user should feel no difference. ``route.match(...)`` is **not** supported anymore, use :meth:`aiohttp.web.AbstractResource.resolve` instead. ``app.router.add_route(method, path, handler, name='name')`` now is just shortcut for:: resource = app.router.add_resource(path, name=name) route = resource.add_route(method, handler) return route ``app.router.register_route(...)`` is still supported, it creates ``aiohttp.web.ResourceAdapter`` for every call (but it's deprecated now). 
================================================ FILE: docs/powered_by.rst ================================================ .. _aiohttp-powered-by: Powered by aiohttp ================== Web sites powered by aiohttp. Feel free to fork documentation on github, add a link to your site and make a Pull Request! * `Farmer Business Network `_ * `Home Assistant `_ * `KeepSafe `_ * `Skyscanner Hotels `_ * `Ocean S.A. `_ * `GNS3 `_ * `TutorCruncher socket `_ * `Eyepea - Custom telephony solutions `_ * `ALLOcloud - Telephony in the cloud `_ * `helpmanual - comprehensive help and man page database `_ * `bedevere `_ - CPython's GitHub bot, helps maintain and identify issues with a CPython pull request. * `miss-islington `_ - CPython's GitHub bot, backports and merge CPython's pull requests * `noa technologies - Bike-sharing management platform `_ - SSE endpoint, pushes real time updates of bikes location. * `Wargaming: World of Tanks `_ * `Yandex `_ * `Rambler `_ * `Escargot `_ - Chat server * `Prom.ua `_ - Online trading platform * `globo.com `_ - (some parts) Brazilian largest media portal * `Glose `_ - Social reader for E-Books * `Emoji Generator `_ - Text icon generator * `SerpsBot Google Search API `_ - SerpsBot Google Search API * `PyChess `_ - Chess variant server ================================================ FILE: docs/spelling_wordlist.txt ================================================ abc addons aiodns aioes aiohttp aiohttpdemo aiohttp’s aiopg al alives api api’s app app’s apps arg args armv Arsenic async asyncio asyncpg asynctest attrs auth autocalculated autodetection autoformatter autoformatters autogenerates autogeneration awaitable backoff backend backends backport Backport Backporting backports BaseEventLoop basename BasicAuth behaviour BodyPartReader boolean botocore brotli Brotli brotlicffi brotlipy bugfix bugfixes Bugfixes builtin BytesIO callables cancelled canonicalization canonicalize cchardet cChardet ceil changelog Changelog chardet Chardet charset 
charsetdetect chunked chunking CIMultiDict ClientSession cls cmd codebase codec Codings committer committers config Config configs conjunction contextmanager CookieJar coroutine Coroutine coroutines cpu CPython css ctor Ctrl cython Cython Cythonize cythonized de deduplicate defs Dependabot deprecations deserialization DER dev Dev dict Dict Discord django Django dns DNSResolver docstring docstrings DoS downstreams Dup elasticsearch encodings env environ eof epoll et etag ETag expirations Facebook facto fallback fallbacks filename finalizers formatter formatters frontend getall gethostbyname github google gunicorn gunicorn’s gzipped hackish highlevel hostnames HTTPException HttpProcessingError httpretty https hostname impl incapsulates Indices infos initializer inline intaking io IoT ip IP ipdb ipv IPv ish isort iterable iterables javascript Jinja jitter json keepalive keepalived keepalives keepaliving kib KiB kwarg kwargs latin lifecycle linux llhttp localhost Locator login lookup lookups lossless lowercased Mako manylinux metadata MiB microservice middleware middlewares miltidict misbehaviors Mixcloud Mongo msg MsgType multi multidict multidict’s multidicts Multidicts multipart Multipart musllinux mypy Nagle Nagle's NFS namedtuple nameservers namespace netrc nginx Nginx Nikolay noop normalizer nowait OAuth Online optimizations orjson os outcoming Overridable Paolini param params parsers pathlib payloads peername performant pickleable ping pipelining pluggable plugin poller pong Postgres pre preloaded proactor programmatically proxied PRs pubsub Punycode py pydantic pyenv pyflakes pyright pytest Pytest qop Quickstart quickstart quote’s rc readline readonly readpayload rebase redirections Redis refactor refactored refactoring referenceable regex regexps regexs reloader renderer renderers repo repr repr’s RequestContextManager request’s Request’s requote requoting resolvehost resolvers reusage reuseconn riscv64 Runit runtime runtimes sa Satisfiable scalability schemas 
seekable sendfile serializable serializer shourtcuts skipuntil Skyscanner SocketSocketTransport ssl SSLContext startup stateful storages subapplication subclassed subclasses subdirectory submodules subpackage subprotocol subprotocols subtype supervisord Supervisord Svetlov symlink symlinks syscall syscalls Systemd tarball TCP teardown Teardown TestClient Testsuite Tf timestamps TLS tmp tmpdir toolbar toplevel towncrier tp tuples UI un unawaited unclosed undercounting unescaped unhandled unicode unittest Unittest unpickler untrusted unix unobvious unsets unstripped untyped uppercased upstr url urldispatcher urlencoded url’s urls utf utils uvloop uWSGI vcvarsall vendored vendoring waituntil wakeup wakeups webapp websocket websocket’s websockets Websockets wildcard Winloop Workflow ws wsgi wss www xxx yarl zlib zstandard zstd ================================================ FILE: docs/streams.rst ================================================ .. currentmodule:: aiohttp .. _aiohttp-streams: Streaming API ============= ``aiohttp`` uses streams for retrieving *BODIES*: :attr:`aiohttp.web.BaseRequest.content` and :attr:`aiohttp.ClientResponse.content` are properties with stream API. .. class:: StreamReader :canonical: aiohttp.streams.StreamReader The reader from incoming stream. User should never instantiate streams manually but use existing :attr:`aiohttp.web.BaseRequest.content` and :attr:`aiohttp.ClientResponse.content` properties for accessing raw BODY data. Reading Attributes and Methods ------------------------------ .. method:: StreamReader.read(n=-1) :async: Read up to a maximum of *n* bytes. If *n* is not provided, or set to ``-1``, read until EOF and return all read bytes. When *n* is provided, data will be returned as soon as it is available. Therefore it will return less than *n* bytes if there are less than *n* bytes in the buffer. If the EOF was received and the internal buffer is empty, return an empty bytes object. 
:param int n: maximum number of bytes to read, ``-1`` for the whole stream. :return bytes: the given data .. method:: StreamReader.readany() :async: Read next data portion for the stream. Returns immediately if internal buffer has a data. :return bytes: the given data .. method:: StreamReader.readexactly(n) :async: Read exactly *n* bytes. Raise an :exc:`asyncio.IncompleteReadError` if the end of the stream is reached before *n* can be read, the :attr:`asyncio.IncompleteReadError.partial` attribute of the exception contains the partial read bytes. :param int n: how many bytes to read. :return bytes: the given data .. method:: StreamReader.readline() :async: Read one line, where “line” is a sequence of bytes ending with ``\n``. If EOF is received, and ``\n`` was not found, the method will return the partial read bytes. If the EOF was received and the internal buffer is empty, return an empty bytes object. :return bytes: the given line .. method:: StreamReader.readuntil(separator="\n") :async: Read until separator, where `separator` is a sequence of bytes. If EOF is received, and `separator` was not found, the method will return the partial read bytes. If the EOF was received and the internal buffer is empty, return an empty bytes object. .. versionadded:: 3.8 :return bytes: the given data .. method:: StreamReader.readchunk() :async: Read a chunk of data as it was received by the server. Returns a tuple of (data, end_of_HTTP_chunk). When chunked transfer encoding is used, end_of_HTTP_chunk is a :class:`bool` indicating if the end of the data corresponds to the end of a HTTP chunk, otherwise it is always ``False``. :return tuple[bytes, bool]: a chunk of data and a :class:`bool` that is ``True`` when the end of the returned chunk corresponds to the end of a HTTP chunk. .. attribute:: StreamReader.total_raw_bytes The number of bytes of raw data downloaded (before decompression). Readonly :class:`int` property. 
Asynchronous Iteration Support ------------------------------ Stream reader supports asynchronous iteration over BODY. By default it iterates over lines:: async for line in response.content: print(line) Also there are methods for iterating over data chunks with maximum size limit and over any available data. .. method:: StreamReader.iter_chunked(n) :async: Iterates over data chunks with maximum size limit:: async for data in response.content.iter_chunked(1024): print(data) To get chunks that are exactly *n* bytes, you could use the `asyncstdlib.itertools `_ module:: chunks = batched(chain.from_iterable(response.content.iter_chunked(n)), n) async for data in chunks: print(data) .. method:: StreamReader.iter_any() :async: Iterates over data chunks in order of intaking them into the stream:: async for data in response.content.iter_any(): print(data) .. method:: StreamReader.iter_chunks() :async: Iterates over data chunks as received from the server:: async for data, _ in response.content.iter_chunks(): print(data) If chunked transfer encoding is used, the original http chunks formatting can be retrieved by reading the second element of returned tuples:: buffer = b"" async for data, end_of_http_chunk in response.content.iter_chunks(): buffer += data if end_of_http_chunk: print(buffer) buffer = b"" Helpers ------- .. method:: StreamReader.exception() Get the exception occurred on data reading. .. method:: is_eof() Return ``True`` if EOF was reached. Internal buffer may be not empty at the moment. .. seealso:: :meth:`StreamReader.at_eof` .. method:: StreamReader.at_eof() Return ``True`` if the buffer is empty and EOF was reached. .. method:: StreamReader.read_nowait(n=None) Returns data from internal buffer if any, empty bytes object otherwise. Raises :exc:`RuntimeError` if other coroutine is waiting for stream. :param int n: how many bytes to read, ``-1`` for the whole internal buffer. :return bytes: the given data .. 
method:: StreamReader.unread_data(data) Roll back reading some data from the stream, inserting it at the head of the buffer. :param bytes data: data to push back into the stream. .. warning:: The method does not wake up waiters. E.g. :meth:`~StreamReader.read` will not be resumed. .. method:: wait_eof() :async: Wait for EOF. The given data may be accessible by upcoming read calls. ================================================ FILE: docs/structures.rst ================================================ .. currentmodule:: aiohttp .. _aiohttp-structures: Common data structures ====================== Common data structures used by *aiohttp* internally. FrozenList ---------- A list-like structure which implements :class:`collections.abc.MutableSequence`. The list is *mutable* unless :meth:`FrozenList.freeze` is called, after that the list modification raises :exc:`RuntimeError`. .. class:: FrozenList(items) Construct a new *non-frozen* list from *items* iterable. The list implements all :class:`collections.abc.MutableSequence` methods plus two additional APIs. .. attribute:: frozen A read-only property, ``True`` if the list is *frozen* (modifications are forbidden). .. method:: freeze() Freeze the list. There is no way to *thaw* it back. ChainMapProxy ------------- An *immutable* version of :class:`collections.ChainMap`. Internally the proxy is a list of mappings (dictionaries), if the requested key is not present in the first mapping the second is looked up and so on. The class supports :class:`collections.abc.Mapping` interface. .. class:: ChainMapProxy(maps) Create a new chained mapping proxy from a list of mappings (*maps*). .. versionadded:: 3.2 ================================================ FILE: docs/testing.rst ================================================ .. module:: aiohttp.test_utils ..
_aiohttp-testing: Testing ======= Testing aiohttp web servers --------------------------- aiohttp provides plugin for *pytest* making writing web server tests extremely easy, it also provides :ref:`test framework agnostic utilities ` for testing with other frameworks such as :ref:`unittest `. Before starting to write your tests, you may also be interested on reading :ref:`how to write testable services` that interact with the loop. For using pytest plugin please install pytest-aiohttp_ library: .. code-block:: shell $ pip install pytest-aiohttp If you don't want to install *pytest-aiohttp* for some reason you may insert ``pytest_plugins = 'aiohttp.pytest_plugin'`` line into ``conftest.py`` instead for the same functionality. The Test Client and Servers ~~~~~~~~~~~~~~~~~~~~~~~~~~~ *aiohttp* test utils provides a scaffolding for testing aiohttp-based web servers. They consist of two parts: running test server and making HTTP requests to this server. :class:`~aiohttp.test_utils.TestServer` runs :class:`aiohttp.web.Application` based server, :class:`~aiohttp.test_utils.RawTestServer` starts :class:`aiohttp.web.Server` low level server. For performing HTTP requests to these servers you have to create a test client: :class:`~aiohttp.test_utils.TestClient` instance. The client incapsulates :class:`aiohttp.ClientSession` by providing proxy methods to the client for common operations such as *ws_connect*, *get*, *post*, etc. Pytest ~~~~~~ .. currentmodule:: pytest_aiohttp The :data:`aiohttp_client` fixture available from pytest-aiohttp_ plugin allows you to create a client to make requests to test your app. 
To run these examples, you need to use `--asyncio-mode=auto` or add to your pytest config file:: asyncio_mode = auto A simple test would be:: from aiohttp import web async def hello(request): return web.Response(text='Hello, world') async def test_hello(aiohttp_client): app = web.Application() app.router.add_get('/', hello) client = await aiohttp_client(app) resp = await client.get('/') assert resp.status == 200 text = await resp.text() assert 'Hello, world' in text It also provides access to the app instance allowing tests to check the state of the app. Tests can be made even more succinct with a fixture to create an app test client:: import pytest from aiohttp import web value = web.AppKey("value", str) async def previous(request): if request.method == 'POST': request.app[value] = (await request.post())['value'] return web.Response(body=b'thanks for the data') return web.Response( body='value: {}'.format(request.app[value]).encode('utf-8')) @pytest.fixture async def cli(aiohttp_client): app = web.Application() app.router.add_get('/', previous) app.router.add_post('/', previous) return await aiohttp_client(app) async def test_set_value(cli): resp = await cli.post('/', data={'value': 'foo'}) assert resp.status == 200 assert await resp.text() == 'thanks for the data' assert cli.server.app[value] == 'foo' async def test_get_value(cli): cli.server.app[value] = 'bar' resp = await cli.get('/') assert resp.status == 200 assert await resp.text() == 'value: bar' Pytest tooling has the following fixtures: .. data:: aiohttp_server(app, *, port=None, **kwargs) A fixture factory that creates :class:`~aiohttp.test_utils.TestServer`:: async def test_f(aiohttp_server): app = web.Application() # fill route table server = await aiohttp_server(app) The server will be destroyed on exit from test function. *app* is the :class:`aiohttp.web.Application` used to start server. *port* optional, port the server is run at, if not provided a random unused port is used. .. 
versionadded:: 3.0 *kwargs* are parameters passed to :meth:`aiohttp.web.AppRunner` .. versionchanged:: 3.0 .. deprecated:: 3.2 The fixture was renamed from ``test_server`` to ``aiohttp_server``. .. data:: aiohttp_client(app, server_kwargs=None, **kwargs) aiohttp_client(server, **kwargs) aiohttp_client(raw_server, **kwargs) A fixture factory that creates :class:`~aiohttp.test_utils.TestClient` for access to tested server:: async def test_f(aiohttp_client): app = web.Application() # fill route table client = await aiohttp_client(app) resp = await client.get('/') *client* and responses are cleaned up after test function finishing. The fixture accepts :class:`aiohttp.web.Application`, :class:`aiohttp.test_utils.TestServer` or :class:`aiohttp.test_utils.RawTestServer` instance. *server_kwargs* are parameters passed to the test server if an app is passed, else ignored. *kwargs* are parameters passed to :class:`aiohttp.test_utils.TestClient` constructor. .. versionchanged:: 3.0 The fixture was renamed from ``test_client`` to ``aiohttp_client``. .. data:: aiohttp_raw_server(handler, *, port=None, **kwargs) A fixture factory that creates :class:`~aiohttp.test_utils.RawTestServer` instance from given web handler.:: async def test_f(aiohttp_raw_server, aiohttp_client): async def handler(request): return web.Response(text="OK") raw_server = await aiohttp_raw_server(handler) client = await aiohttp_client(raw_server) resp = await client.get('/') *handler* should be a coroutine which accepts a request and returns response, e.g. *port* optional, port the server is run at, if not provided a random unused port is used. .. versionadded:: 3.0 .. data:: aiohttp_unused_port() Function to return an unused port number for IPv4 TCP protocol:: async def test_f(aiohttp_client, aiohttp_unused_port): port = aiohttp_unused_port() app = web.Application() # fill route table client = await aiohttp_client(app, server_kwargs={'port': port}) ... .. 
versionchanged:: 3.0 The fixture was renamed from ``unused_port`` to ``aiohttp_unused_port``. .. data:: aiohttp_client_cls A fixture for passing custom :class:`~aiohttp.test_utils.TestClient` implementations:: class MyClient(TestClient): async def login(self, *, user, pw): payload = {"username": user, "password": pw} return await self.post("/login", json=payload) @pytest.fixture def aiohttp_client_cls(): return MyClient async def test_login(aiohttp_client): app = web.Application() client = await aiohttp_client(app) await client.login(user="admin", pw="s3cr3t") If you want to switch between different clients in tests, you can use the usual ``pytest`` machinery. Example using test markers:: class RESTfulClient(TestClient): ... class GraphQLClient(TestClient): ... @pytest.fixture def aiohttp_client_cls(request): if request.node.get_closest_marker('rest') is not None: return RESTfulClient if request.node.get_closest_marker('graphql') is not None: return GraphQLClient return TestClient @pytest.mark.rest async def test_rest(aiohttp_client) -> None: client: RESTfulClient = await aiohttp_client(Application()) ... @pytest.mark.graphql async def test_graphql(aiohttp_client) -> None: client: GraphQLClient = await aiohttp_client(Application()) ... .. _aiohttp-testing-unittest-example: .. _aiohttp-testing-unittest-style: Unittest ~~~~~~~~ .. currentmodule:: aiohttp.test_utils To test applications with the standard library's unittest or unittest-based functionality, the AioHTTPTestCase is provided:: from aiohttp.test_utils import AioHTTPTestCase from aiohttp import web class MyAppTestCase(AioHTTPTestCase): async def get_application(self): """ Override the get_application method to return your application.
""" async def hello(request): return web.Response(text='Hello, world') app = web.Application() app.router.add_get('/', hello) return app async def test_example(self): async with self.client.request("GET", "/") as resp: self.assertEqual(resp.status, 200) text = await resp.text() self.assertIn("Hello, world", text) .. class:: AioHTTPTestCase A base class to allow for unittest web applications using aiohttp. Derived from :class:`unittest.IsolatedAsyncioTestCase` See :class:`unittest.TestCase` and :class:`unittest.IsolatedAsyncioTestCase` for inherited methods and behavior. This class additionally provides the following: .. attribute:: client an aiohttp test client, :class:`TestClient` instance. .. attribute:: server an aiohttp test server, :class:`TestServer` instance. .. versionadded:: 2.3 .. attribute:: app The application returned by :meth:`~aiohttp.test_utils.AioHTTPTestCase.get_application` (:class:`aiohttp.web.Application` instance). .. method:: get_client() :async: This async method can be overridden to return the :class:`TestClient` object used in the test. :return: :class:`TestClient` instance. .. versionadded:: 2.3 .. method:: get_server() :async: This async method can be overridden to return the :class:`TestServer` object used in the test. :return: :class:`TestServer` instance. .. versionadded:: 2.3 .. method:: get_application() :async: This async method should be overridden to return the :class:`aiohttp.web.Application` object to test. :return: :class:`aiohttp.web.Application` instance. .. method:: asyncSetUp() :async: This async method can be overridden to execute asynchronous code during the ``setUp`` stage of the ``TestCase``:: async def asyncSetUp(self): await super().asyncSetUp() await foo() .. versionadded:: 2.3 .. versionchanged:: 3.8 ``await super().asyncSetUp()`` call is required. .. 
method:: asyncTearDown() :async: This async method can be overridden to execute asynchronous code during the ``tearDown`` stage of the ``TestCase``:: async def asyncTearDown(self): await super().asyncTearDown() await foo() .. versionadded:: 2.3 .. versionchanged:: 3.8 ``await super().asyncTearDown()`` call is required. Faking request object ^^^^^^^^^^^^^^^^^^^^^ aiohttp provides test utility for creating fake :class:`aiohttp.web.Request` objects: :func:`aiohttp.test_utils.make_mocked_request`, it could be useful in case of simple unit tests, like handler tests, or to simulate error conditions that are hard to reproduce on a real server:: from aiohttp import web from aiohttp.test_utils import make_mocked_request def handler(request): assert request.headers.get('token') == 'x' return web.Response(body=b'data') def test_handler(): req = make_mocked_request('GET', '/', headers={'token': 'x'}) resp = handler(req) assert resp.body == b'data' .. warning:: We don't recommend applying :func:`~aiohttp.test_utils.make_mocked_request` everywhere for testing web-handler's business object -- please use test client and real networking via 'localhost' as shown in examples before. :func:`~aiohttp.test_utils.make_mocked_request` exists only for testing complex cases (e.g. emulating network errors) which are extremely hard or even impossible to test in a conventional way. .. function:: make_mocked_request(method, path, headers=None, *, \ version=HttpVersion(1, 1), \ closing=False, \ app=None, \ match_info=sentinel, \ reader=sentinel, \ writer=sentinel, \ transport=sentinel, \ payload=sentinel, \ sslcontext=None, \ loop=...) Creates a mocked web.Request for testing purposes. Useful in unit tests, when spinning up a full web server is overkill or specific conditions and errors are hard to trigger. :param method: str, that represents HTTP method, like GET, POST.
:type method: str :param path: str, The URL including *PATH INFO* without the host or scheme :type path: str :param headers: mapping containing the headers. Can be anything accepted by the multidict.CIMultiDict constructor. :type headers: dict, multidict.CIMultiDict, list of tuple(str, str) :param match_info: mapping containing the info to match with url parameters. :type match_info: dict :param version: namedtuple with encoded HTTP version :type version: aiohttp.protocol.HttpVersion :param closing: flag indicates that connection should be closed after response. :type closing: bool :param app: the aiohttp.web application attached for fake request :type app: aiohttp.web.Application :param writer: object for managing outcoming data :type writer: aiohttp.StreamWriter :param transport: asyncio transport instance :type transport: asyncio.Transport :param payload: raw payload reader object :type payload: aiohttp.StreamReader :param sslcontext: ssl.SSLContext object, for HTTPS connection :type sslcontext: ssl.SSLContext :param loop: An event loop instance, mocked loop by default. :type loop: :class:`asyncio.AbstractEventLoop` :return: :class:`aiohttp.web.Request` object. .. versionadded:: 2.3 *match_info* parameter. .. _aiohttp-testing-writing-testable-services: .. 
_aiohttp-testing-framework-agnostic-utilities: Framework Agnostic Utilities ---------------------------- High level test creation:: from aiohttp.test_utils import TestClient, TestServer from aiohttp import request async def test(): app = _create_example_app() async with TestClient(TestServer(app)) as client: async def test_get_route(): nonlocal client resp = await client.get("/") assert resp.status == 200 text = await resp.text() assert "Hello, world" in text await test_get_route() If it's preferred to handle the creation / teardown on a more granular basis, the TestClient object can be used directly:: from aiohttp.test_utils import TestClient, TestServer async def test(): app = _create_example_app() client = TestClient(TestServer(app)) await client.start_server() root = "http://127.0.0.1:{}".format(client.port) async def test_get_route(): resp = await client.get("/") assert resp.status == 200 text = await resp.text() assert "Hello, world" in text await test_get_route() await client.close() A full list of the utilities provided can be found at the :data:`api reference ` Testing API Reference --------------------- Test server ~~~~~~~~~~~ Runs given :class:`aiohttp.web.Application` instance on random TCP port. After creation the server is not started yet, use :meth:`~aiohttp.test_utils.BaseTestServer.start_server` for actual server starting and :meth:`~aiohttp.test_utils.BaseTestServer.close` for stopping/cleanup. Test server usually works in conjunction with :class:`aiohttp.test_utils.TestClient` which provides handy client methods for accessing the server. .. class:: BaseTestServer(*, scheme='http', host='127.0.0.1', port=None, socket_factory=get_port_socket) Base class for test servers. :param str scheme: HTTP scheme, non-protected ``"http"`` by default. :param str host: a host for TCP socket, IPv4 *local host* (``'127.0.0.1'``) by default. :param int port: optional port for TCP socket, if not provided a random unused port is used. ..
versionadded:: 3.0 :param collections.abc.Callable[[str,int,socket.AddressFamily],socket.socket] socket_factory: optional Factory to create a socket for the server. By default creates a TCP socket and binds it to ``host`` and ``port``. .. versionadded:: 3.8 .. attribute:: scheme A *scheme* for tested application, ``'http'`` for non-protected run and ``'https'`` for TLS encrypted server. .. attribute:: host *host* used to start a test server. .. attribute:: port *port* used to start the test server. .. attribute:: handler :class:`aiohttp.web.Server` used for HTTP requests serving. .. attribute:: server :class:`asyncio.AbstractServer` used for managing accepted connections. .. attribute:: socket_factory *socket_factory* used to create and bind a server socket. .. versionadded:: 3.8 .. method:: start_server(**kwargs) :async: Start a test server. .. method:: close() :async: Stop and finish executed test server. .. method:: make_url(path) Return an *absolute* :class:`~yarl.URL` for given *path*. .. class:: RawTestServer(handler, *, scheme="http", host='127.0.0.1') Low-level test server (derived from :class:`BaseTestServer`). :param handler: a coroutine for handling web requests. The handler should accept :class:`aiohttp.web.BaseRequest` and return a response instance, e.g. :class:`~aiohttp.web.StreamResponse` or :class:`~aiohttp.web.Response`. The handler could raise :class:`~aiohttp.web.HTTPException` as a signal for non-200 HTTP response. :param str scheme: HTTP scheme, non-protected ``"http"`` by default. :param str host: a host for TCP socket, IPv4 *local host* (``'127.0.0.1'``) by default. :param int port: optional port for TCP socket, if not provided a random unused port is used. .. versionadded:: 3.0 .. class:: TestServer(app, *, scheme="http", host='127.0.0.1') Test server (derived from :class:`BaseTestServer`) for starting :class:`~aiohttp.web.Application`. :param app: :class:`aiohttp.web.Application` instance to run. 
:param str scheme: HTTP scheme, non-protected ``"http"`` by default. :param str host: a host for TCP socket, IPv4 *local host* (``'127.0.0.1'``) by default. :param int port: optional port for TCP socket, if not provided a random unused port is used. .. versionadded:: 3.0 .. attribute:: app :class:`aiohttp.web.Application` instance to run. Test Client ~~~~~~~~~~~ .. class:: TestClient(app_or_server, *, \ scheme='http', host='127.0.0.1', \ cookie_jar=None, **kwargs) A test client used for making calls to tested server. :param app_or_server: :class:`BaseTestServer` instance for making client requests to it. In order to pass an :class:`aiohttp.web.Application` you need to convert it first to :class:`TestServer` first with ``TestServer(app)``. :param cookie_jar: an optional :class:`aiohttp.CookieJar` instance, may be useful with ``CookieJar(unsafe=True, treat_as_secure_origin="http://127.0.0.1")`` option. :param str scheme: HTTP scheme, non-protected ``"http"`` by default. :param str host: a host for TCP socket, IPv4 *local host* (``'127.0.0.1'``) by default. .. attribute:: scheme A *scheme* for tested application, ``'http'`` for non-protected run and ``'https'`` for TLS encrypted server. .. attribute:: host *host* used to start a test server. .. attribute:: port *port* used to start the server .. attribute:: server :class:`BaseTestServer` test server instance used in conjunction with client. .. attribute:: app An alias for ``self.server.app``. return ``None`` if ``self.server`` is not :class:`TestServer` instance(e.g. :class:`RawTestServer` instance for test low-level server). .. attribute:: session An internal :class:`aiohttp.ClientSession`. Unlike the methods on the :class:`TestClient`, client session requests do not automatically include the host in the url queried, and will require an absolute path to the resource. .. method:: start_server(**kwargs) :async: Start a test server. .. method:: close() :async: Stop and finish executed test server. .. 
method:: make_url(path) Return an *absolute* :class:`~yarl.URL` for given *path*. .. method:: request(method, path, *args, **kwargs) :async: Routes a request to tested http server. The interface is identical to :meth:`aiohttp.ClientSession.request`, except the loop kwarg is overridden by the instance used by the test server. .. method:: get(path, *args, **kwargs) :async: Perform an HTTP GET request. .. method:: post(path, *args, **kwargs) :async: Perform an HTTP POST request. .. method:: options(path, *args, **kwargs) :async: Perform an HTTP OPTIONS request. .. method:: head(path, *args, **kwargs) :async: Perform an HTTP HEAD request. .. method:: put(path, *args, **kwargs) :async: Perform an HTTP PUT request. .. method:: patch(path, *args, **kwargs) :async: Perform an HTTP PATCH request. .. method:: delete(path, *args, **kwargs) :async: Perform an HTTP DELETE request. .. method:: ws_connect(path, *args, **kwargs) :async: Initiate websocket connection. The api corresponds to :meth:`aiohttp.ClientSession.ws_connect`. Utilities ~~~~~~~~~ .. function:: unused_port() Return an unused port number for IPv4 TCP protocol. :return int: ephemeral port number which could be reused by test server. .. function:: loop_context(loop_factory=) A contextmanager that creates an event_loop, for test purposes. Handles the creation and cleanup of a test loop. .. function:: setup_test_loop(loop_factory=) Create and return an :class:`asyncio.AbstractEventLoop` instance. The caller should also call teardown_test_loop, once they are done with the loop. .. note:: As side effect the function changes asyncio *default loop* by :func:`asyncio.set_event_loop` call. Previous default loop is not restored. It should not be a problem for test suite: every test expects a new test loop instance anyway. .. versionchanged:: 3.1 The function installs a created event loop as *default*. .. function:: teardown_test_loop(loop) Teardown and cleanup an event_loop created by setup_test_loop. 
:param loop: the loop to teardown :type loop: asyncio.AbstractEventLoop .. _pytest: http://pytest.org/latest/ .. _pytest-aiohttp: https://pypi.python.org/pypi/pytest-aiohttp ================================================ FILE: docs/third_party.rst ================================================ .. _aiohttp-3rd-party: Third-Party libraries ===================== aiohttp is not just a library for making HTTP requests and creating web servers. It is the foundation for libraries built *on top* of aiohttp. This page is a list of these tools. Please feel free to add your open source library if it's not listed yet by making a pull request to https://github.com/aio-libs/aiohttp/ * Why would you want to include your awesome library in this list? * Because the list increases your library visibility. People will have an easy way to find it. Officially supported -------------------- This list contains libraries which are supported by the *aio-libs* team and located on https://github.com/aio-libs aiohttp extensions ^^^^^^^^^^^^^^^^^^ - `aiohttp-apischema `_ provides automatic API schema generation and validation of user input for :mod:`aiohttp.web`. - `aiohttp-session `_ provides sessions for :mod:`aiohttp.web`. - `aiohttp-debugtoolbar `_ is a library for *debug toolbar* support for :mod:`aiohttp.web`. - `aiohttp-security `_ auth and permissions for :mod:`aiohttp.web`. - `aiohttp-devtools `_ provides development tools for :mod:`aiohttp.web` applications. - `aiohttp-cors `_ CORS support for aiohttp. - `aiohttp-sse `_ Server-sent events support for aiohttp. - `pytest-aiohttp `_ pytest plugin for aiohttp support. - `aiohttp-mako `_ Mako template renderer for aiohttp.web. - `aiohttp-jinja2 `_ Jinja2 template renderer for aiohttp.web. - `aiozipkin `_ distributed tracing instrumentation for `aiohttp` client and server. Database drivers ^^^^^^^^^^^^^^^^ - `aiopg `_ PostgreSQL async driver. - `aiomysql `_ MySQL async driver. - `aioredis `_ Redis async driver. 
Other tools ^^^^^^^^^^^ - `aiodocker `_ Python Docker API client based on asyncio and aiohttp. - `aiobotocore `_ asyncio support for botocore library using aiohttp. Approved third-party libraries ------------------------------ These libraries are not part of ``aio-libs`` but they have proven to be very well written and highly recommended for usage. - `uvloop `_ Ultra fast implementation of asyncio event loop on top of ``libuv``. We highly recommend to use this instead of standard ``asyncio``. Database drivers ^^^^^^^^^^^^^^^^ - `asyncpg `_ Another PostgreSQL async driver. It's much faster than ``aiopg`` but is not a drop-in replacement -- the API is different. But, please take a look at it -- the driver is incredibly fast. OpenAPI / Swagger extensions ---------------------------- Extensions bringing `OpenAPI `_ support to aiohttp web servers. - `aiohttp-apispec `_ Build and document REST APIs with ``aiohttp`` and ``apispec``. - `aiohttp_apiset `_ Package to build routes using swagger specification. - `aiohttp-pydantic `_ An ``aiohttp.View`` to validate the HTTP request's body, query-string, and headers regarding function annotations and generate OpenAPI doc. - `aiohttp-swagger `_ Swagger API Documentation builder for aiohttp server. - `aiohttp-swagger3 `_ Library for Swagger documentation builder and validating aiohttp requests using swagger specification 3.0. - `aiohttp-swaggerify `_ Library to automatically generate swagger2.0 definition for aiohttp endpoints. - `aio-openapi `_ Asynchronous web middleware for aiohttp and serving Rest APIs with OpenAPI v3 specification and with optional PostgreSQL database bindings. - `rororo `_ Implement ``aiohttp.web`` OpenAPI 3 server applications with schema first approach. Others ------ Here is a list of other known libraries that do not belong in the former categories. We cannot vouch for the quality of these libraries, use them at your own risk. 
Please add your library reference here first and after some time ask to raise the status. - `pytest-aiohttp-client `_ Pytest fixture with simpler api, payload decoding and status code assertions. - `python-proxy-headers `_ provides ``aiohttp_proxy`` extension for receiving custom response headers from a proxy server - `octomachinery `_ A framework for developing GitHub Apps and GitHub Actions. - `aiomixcloud `_ Mixcloud API wrapper for Python and Async IO. - `aiohttp-cache `_ A cache system for aiohttp server. - `aiocache `_ Caching for asyncio with multiple backends (framework agnostic) - `gain `_ Web crawling framework based on asyncio for everyone. - `aiohttp-validate `_ Simple library that helps you validate your API endpoints requests/responses with json schema. - `raven-aiohttp `_ An aiohttp transport for raven-python (Sentry client). - `webargs `_ A friendly library for parsing HTTP request arguments, with built-in support for popular web frameworks, including Flask, Django, Bottle, Tornado, Pyramid, webapp2, Falcon, and aiohttp. - `aiohttpretty `_ A simple asyncio compatible httpretty mock using aiohttp. - `aioresponses `_ a helper for mock/fake web requests in python aiohttp package. - `aiohttp-transmute `_ A transmute implementation for aiohttp. - `aiohttp-login `_ Registration and authorization (including social) for aiohttp applications. - `aiohttp_utils `_ Handy utilities for building aiohttp.web applications. - `aiohttpproxy `_ Simple aiohttp HTTP proxy. - `aiohttp_traversal `_ Traversal based router for aiohttp.web. - `aiohttp_autoreload `_ Makes aiohttp server auto-reload on source code change. - `gidgethub `_ An async GitHub API library for Python. - `aiohttp-rpc `_ A simple JSON-RPC for aiohttp. - `aiohttp_jrpc `_ aiohttp JSON-RPC service. - `fbemissary `_ A bot framework for the Facebook Messenger platform, built on asyncio and aiohttp. - `aioslacker `_ slacker wrapper for asyncio. - `aioreloader `_ Port of tornado reloader to asyncio. 
- `aiohttp_babel `_ Babel localization support for aiohttp. - `python-mocket `_ a socket mock framework - for all kinds of socket animals, web-clients included. - `aioraft `_ asyncio RAFT algorithm based on aiohttp. - `home-assistant `_ Open-source home automation platform running on Python 3. - `discord.py `_ Discord client library. - `aiogram `_ A fully asynchronous library for Telegram Bot API written with asyncio and aiohttp. - `aiohttp-graphql `_ GraphQL and GraphIQL interface for aiohttp. - `aiohttp-sentry `_ An aiohttp middleware for reporting errors to Sentry. - `aiohttp-datadog `_ An aiohttp middleware for reporting metrics to DataDog. - `async-v20 `_ Asynchronous FOREX client for OANDA's v20 API. - `aiohttp-jwt `_ An aiohttp middleware for JWT(JSON Web Token) support. - `AWS Xray Python SDK `_ Native tracing support for Aiohttp applications. - `GINO `_ An asyncio ORM on top of SQLAlchemy core, delivered with an aiohttp extension. - `New Relic `_ An aiohttp middleware for reporting your `Python application performance `_ metrics to New Relic. - `eider-py `_ Python implementation of the `Eider RPC protocol `_. - `asynapplicationinsights `_ A client for `Azure Application Insights `_ implemented using ``aiohttp`` client, including a middleware for ``aiohttp`` servers to collect web apps telemetry. - `aiogmaps `_ Asynchronous client for Google Maps API Web Services. - `DBGR `_ Terminal based tool to test and debug HTTP APIs with ``aiohttp``. - `aiohttp-middlewares `_ Collection of useful middlewares for ``aiohttp.web`` applications. - `aiohttp-tus `_ `tus.io `_ protocol implementation for ``aiohttp.web`` applications. - `aiohttp-sse-client `_ A Server-Sent Event python client base on aiohttp. - `aiohttp-retry `_ Wrapper for aiohttp client for retrying requests. - `aiohttp-socks `_ SOCKS proxy connector for aiohttp. - `aiohttp-catcher `_ An aiohttp middleware library for centralized error handling in aiohttp servers. 
- `rsocket `_ Python implementation of `RSocket protocol `_. - `nacl_middleware `_ An aiohttp middleware library for asymmetric encryption of data transmitted via http and/or websocket connections. - `aiohttp-asgi-connector `_ An aiohttp connector for using a ``ClientSession`` to interface directly with separate ASGI applications. - `aiohttp-openmetrics `_ An aiohttp middleware for exposing Prometheus metrics. - `wireup `_ Performant, concise, and easy-to-use dependency injection container. ================================================ FILE: docs/tracing_reference.rst ================================================ .. currentmodule:: aiohttp .. _aiohttp-client-tracing-reference: Tracing Reference ================= .. versionadded:: 3.0 A reference for client tracing API. .. seealso:: :ref:`aiohttp-client-tracing` for tracing usage instructions. Request life cycle ------------------ A request goes through the following stages and corresponding fallbacks. Overview ^^^^^^^^ .. graphviz:: digraph { start[shape=point, xlabel="start", width="0.1"]; redirect[shape=box]; end[shape=point, xlabel="end ", width="0.1"]; exception[shape=oval]; acquire_connection[shape=box]; headers_received[shape=box]; headers_sent[shape=box]; chunk_sent[shape=box]; chunk_received[shape=box]; start -> acquire_connection; acquire_connection -> headers_sent; headers_sent -> headers_received; headers_sent -> chunk_sent; chunk_sent -> chunk_sent; chunk_sent -> headers_received; headers_received -> chunk_received; chunk_received -> chunk_received; chunk_received -> end; headers_received -> redirect; headers_received -> end; redirect -> headers_sent; chunk_received -> exception; chunk_sent -> exception; headers_sent -> exception; } .. 
list-table:: :header-rows: 1 * - Name - Description * - start - on_request_start * - redirect - on_request_redirect * - acquire_connection - Connection acquiring * - headers_received - * - exception - on_request_exception * - end - on_request_end * - headers_sent - on_request_headers_sent * - chunk_sent - on_request_chunk_sent * - chunk_received - on_response_chunk_received Connection acquiring ^^^^^^^^^^^^^^^^^^^^ .. graphviz:: digraph { begin[shape=point, xlabel="begin", width="0.1"]; end[shape=point, xlabel="end ", width="0.1"]; exception[shape=oval]; queued_start[shape=box]; queued_end[shape=box]; create_start[shape=box]; create_end[shape=box]; reuseconn[shape=box]; resolve_dns[shape=box]; sock_connect[shape=box]; begin -> reuseconn; begin -> create_start; create_start -> resolve_dns; resolve_dns -> exception; resolve_dns -> sock_connect; sock_connect -> exception; sock_connect -> create_end -> end; begin -> queued_start; queued_start -> queued_end; queued_end -> reuseconn; queued_end -> create_start; reuseconn -> end; } .. list-table:: :header-rows: 1 * - Name - Description * - begin - * - end - * - queued_start - on_connection_queued_start * - create_start - on_connection_create_start * - reuseconn - on_connection_reuseconn * - queued_end - on_connection_queued_end * - create_end - on_connection_create_end * - exception - Exception raised * - resolve_dns - DNS resolving * - sock_connect - Connection establishment DNS resolving ^^^^^^^^^^^^^ .. graphviz:: digraph { begin[shape=point, xlabel="begin", width="0.1"]; end[shape=point, xlabel="end", width="0.1"]; exception[shape=oval]; resolve_start[shape=box]; resolve_end[shape=box]; cache_hit[shape=box]; cache_miss[shape=box]; begin -> cache_hit -> end; begin -> cache_miss -> resolve_start; resolve_start -> resolve_end -> end; resolve_start -> exception; } .. 
list-table:: :header-rows: 1 * - Name - Description * - begin - * - end - * - exception - Exception raised * - resolve_end - on_dns_resolvehost_end * - resolve_start - on_dns_resolvehost_start * - cache_hit - on_dns_cache_hit * - cache_miss - on_dns_cache_miss Classes ------- .. class:: TraceConfig(trace_config_ctx_factory=SimpleNamespace) :canonical: aiohttp.tracing.TraceConfig Trace config is the configuration object used to trace requests launched by a :class:`ClientSession` object using different events related to different parts of the request flow. :param trace_config_ctx_factory: factory used to create trace contexts, default class used :class:`types.SimpleNamespace` .. method:: trace_config_ctx(trace_request_ctx=None) :param trace_request_ctx: Will be used to pass as a kw for the ``trace_config_ctx_factory``. Build a new trace context from the config. Every signal handler should have the following signature:: async def on_signal(session, context, params): ... where ``session`` is :class:`ClientSession` instance, ``context`` is an object returned by :meth:`trace_config_ctx` call and ``params`` is a data class with signal parameters. The type of ``params`` depends on subscribed signal and described below. .. attribute:: on_request_start Property that gives access to the signals that will be executed when a request starts. ``params`` is :class:`aiohttp.TraceRequestStartParams` instance. .. attribute:: on_request_chunk_sent Property that gives access to the signals that will be executed when a chunk of request body is sent. ``params`` is :class:`aiohttp.TraceRequestChunkSentParams` instance. .. versionadded:: 3.1 .. attribute:: on_response_chunk_received Property that gives access to the signals that will be executed when a chunk of response body is received. ``params`` is :class:`aiohttp.TraceResponseChunkReceivedParams` instance. .. versionadded:: 3.1 .. 
attribute:: on_request_redirect Property that gives access to the signals that will be executed when a redirect happens during a request flow. ``params`` is :class:`aiohttp.TraceRequestRedirectParams` instance. .. attribute:: on_request_end Property that gives access to the signals that will be executed when a request ends. ``params`` is :class:`aiohttp.TraceRequestEndParams` instance. .. attribute:: on_request_exception Property that gives access to the signals that will be executed when a request finishes with an exception. ``params`` is :class:`aiohttp.TraceRequestExceptionParams` instance. .. attribute:: on_connection_queued_start Property that gives access to the signals that will be executed when a request has been queued waiting for an available connection. ``params`` is :class:`aiohttp.TraceConnectionQueuedStartParams` instance. .. attribute:: on_connection_queued_end Property that gives access to the signals that will be executed when a request that was queued already has an available connection. ``params`` is :class:`aiohttp.TraceConnectionQueuedEndParams` instance. .. attribute:: on_connection_create_start Property that gives access to the signals that will be executed when a request creates a new connection. ``params`` is :class:`aiohttp.TraceConnectionCreateStartParams` instance. .. attribute:: on_connection_create_end Property that gives access to the signals that will be executed when a request that created a new connection finishes its creation. ``params`` is :class:`aiohttp.TraceConnectionCreateEndParams` instance. .. attribute:: on_connection_reuseconn Property that gives access to the signals that will be executed when a request reuses a connection. ``params`` is :class:`aiohttp.TraceConnectionReuseconnParams` instance. .. attribute:: on_dns_resolvehost_start Property that gives access to the signals that will be executed when a request starts to resolve the domain related with the request. 
``params`` is :class:`aiohttp.TraceDnsResolveHostStartParams` instance. .. attribute:: on_dns_resolvehost_end Property that gives access to the signals that will be executed when a request finishes to resolve the domain related with the request. ``params`` is :class:`aiohttp.TraceDnsResolveHostEndParams` instance. .. attribute:: on_dns_cache_hit Property that gives access to the signals that will be executed when a request was able to use a cached DNS resolution for the domain related with the request. ``params`` is :class:`aiohttp.TraceDnsCacheHitParams` instance. .. attribute:: on_dns_cache_miss Property that gives access to the signals that will be executed when a request was not able to use a cached DNS resolution for the domain related with the request. ``params`` is :class:`aiohttp.TraceDnsCacheMissParams` instance. .. attribute:: on_request_headers_sent Property that gives access to the signals that will be executed when request headers are sent. ``params`` is :class:`aiohttp.TraceRequestHeadersSentParams` instance. .. versionadded:: 3.8 .. class:: TraceRequestStartParams :canonical: aiohttp.tracing.TraceRequestStartParams See :attr:`TraceConfig.on_request_start` for details. .. attribute:: method Method that will be used to make the request. .. attribute:: url URL that will be used for the request. .. attribute:: headers Headers that will be used for the request, can be mutated. .. class:: TraceRequestChunkSentParams :canonical: aiohttp.tracing.TraceRequestChunkSentParams .. versionadded:: 3.1 See :attr:`TraceConfig.on_request_chunk_sent` for details. .. attribute:: method Method that will be used to make the request. .. attribute:: url URL that will be used for the request. .. attribute:: chunk Bytes of chunk sent .. class:: TraceResponseChunkReceivedParams :canonical: aiohttp.tracing.TraceResponseChunkReceivedParams .. versionadded:: 3.1 See :attr:`TraceConfig.on_response_chunk_received` for details. .. 
attribute:: method Method that will be used to make the request. .. attribute:: url URL that will be used for the request. .. attribute:: chunk Bytes of chunk received .. class:: TraceRequestEndParams :canonical: aiohttp.tracing.TraceRequestEndParams See :attr:`TraceConfig.on_request_end` for details. .. attribute:: method Method used to make the request. .. attribute:: url URL used for the request. .. attribute:: headers Headers used for the request. .. attribute:: response Response :class:`ClientResponse`. .. class:: TraceRequestExceptionParams :canonical: aiohttp.tracing.TraceRequestExceptionParams See :attr:`TraceConfig.on_request_exception` for details. .. attribute:: method Method used to make the request. .. attribute:: url URL used for the request. .. attribute:: headers Headers used for the request. .. attribute:: exception Exception raised during the request. .. class:: TraceRequestRedirectParams :canonical: aiohttp.tracing.TraceRequestRedirectParams See :attr:`TraceConfig.on_request_redirect` for details. .. attribute:: method Method used to get this redirect request. .. attribute:: url URL used for this redirect request. .. attribute:: headers Headers used for this redirect. .. attribute:: response Response :class:`ClientResponse` got from the redirect. .. class:: TraceConnectionQueuedStartParams :canonical: aiohttp.tracing.TraceConnectionQueuedStartParams See :attr:`TraceConfig.on_connection_queued_start` for details. There are no attributes right now. .. class:: TraceConnectionQueuedEndParams :canonical: aiohttp.tracing.TraceConnectionQueuedEndParams See :attr:`TraceConfig.on_connection_queued_end` for details. There are no attributes right now. .. class:: TraceConnectionCreateStartParams :canonical: aiohttp.tracing.TraceConnectionCreateStartParams See :attr:`TraceConfig.on_connection_create_start` for details. There are no attributes right now. .. 
class:: TraceConnectionCreateEndParams :canonical: aiohttp.tracing.TraceConnectionCreateEndParams See :attr:`TraceConfig.on_connection_create_end` for details. There are no attributes right now. .. class:: TraceConnectionReuseconnParams :canonical: aiohttp.tracing.TraceConnectionReuseconnParams See :attr:`TraceConfig.on_connection_reuseconn` for details. There are no attributes right now. .. class:: TraceDnsResolveHostStartParams :canonical: aiohttp.tracing.TraceDnsResolveHostStartParams See :attr:`TraceConfig.on_dns_resolvehost_start` for details. .. attribute:: host Host that will be resolved. .. class:: TraceDnsResolveHostEndParams :canonical: aiohttp.tracing.TraceDnsResolveHostEndParams See :attr:`TraceConfig.on_dns_resolvehost_end` for details. .. attribute:: host Host that has been resolved. .. class:: TraceDnsCacheHitParams :canonical: aiohttp.tracing.TraceDnsCacheHitParams See :attr:`TraceConfig.on_dns_cache_hit` for details. .. attribute:: host Host found in the cache. .. class:: TraceDnsCacheMissParams :canonical: aiohttp.tracing.TraceDnsCacheMissParams See :attr:`TraceConfig.on_dns_cache_miss` for details. .. attribute:: host Host didn't find the cache. .. class:: TraceRequestHeadersSentParams :canonical: aiohttp.tracing.TraceRequestHeadersSentParams See :attr:`TraceConfig.on_request_headers_sent` for details. .. versionadded:: 3.8 .. attribute:: method Method that will be used to make the request. .. attribute:: url URL that will be used for the request. .. attribute:: headers Headers that will be used for the request. ================================================ FILE: docs/utilities.rst ================================================ .. currentmodule:: aiohttp .. _aiohttp-utilities: Utilities ========= Miscellaneous API Shared between Client And Server. .. 
toctree:: :name: utilities :maxdepth: 2 abc multipart multipart_reference streams structures websocket_utilities ================================================ FILE: docs/web.rst ================================================ .. _aiohttp-web: Server ====== .. module:: aiohttp.web The page contains all information about aiohttp Server API: .. toctree:: :name: server :maxdepth: 3 Tutorial Quickstart Advanced Usage Low Level Reference Web Exceptions Logging Testing Deployment ================================================ FILE: docs/web_advanced.rst ================================================ .. currentmodule:: aiohttp.web .. _aiohttp-web-advanced: Web Server Advanced =================== Unicode support --------------- *aiohttp* does :term:`requoting` of incoming request path. Unicode (non-ASCII) symbols are processed transparently on both *route adding* and *resolving* (internally everything is converted to :term:`percent-encoding` form by :term:`yarl` library). But in case of custom regular expressions for :ref:`aiohttp-web-variable-handler` please take care that URL is *percent encoded*: if you pass Unicode patterns they don't match to *requoted* path. .. _aiohttp-web-peer-disconnection: Peer disconnection ------------------ *aiohttp* has 2 approaches to handling client disconnections. If you are familiar with asyncio, or scalability is a concern for your application, we recommend using the handler cancellation method. Raise on read/write (default) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ When a client peer is gone, a subsequent reading or writing raises :exc:`OSError` or a more specific exception like :exc:`ConnectionResetError`. This behavior is similar to classic WSGI frameworks like Flask and Django. The reason for disconnection varies; it can be a network issue or explicit socket closing on the peer side without reading the full server response. 
*aiohttp* handles disconnection properly but you can handle it explicitly, e.g.:: async def handler(request): try: text = await request.text() except OSError: # disconnected .. _web-handler-cancellation: Web handler cancellation ^^^^^^^^^^^^^^^^^^^^^^^^ This method can be enabled using the ``handler_cancellation`` parameter to :func:`run_app`. When a client disconnects, the web handler task will be cancelled. This is recommended as it can reduce the load on your server when there is no client to receive a response. It can also help make your application more resilient to DoS attacks (by requiring an attacker to keep a connection open in order to waste server resources). This behavior is very different from classic WSGI frameworks like Flask and Django. It requires a reasonable level of asyncio knowledge to use correctly without causing issues in your code. We provide some examples here to help understand the complexity and methods needed to deal with them. .. warning:: :term:`web-handler` execution could be canceled on every ``await`` or ``async with`` if client drops connection without reading entire response's BODY. Sometimes it is a desirable behavior: on processing ``GET`` request the code might fetch data from a database or other web resource, the fetching is potentially slow. Canceling this fetch is a good idea: the client dropped the connection already, so there is no reason to waste time and resources (memory etc) by getting data from a DB without any chance to send it back to the client. But sometimes the cancellation is bad: on ``POST`` requests very often it is needed to save data to a DB regardless of connection closing. Cancellation prevention could be implemented in several ways: * Applying :func:`aiojobs.aiohttp.shield` to a coroutine that saves data. * Using aiojobs_ or another third party library to run a task in the background. :func:`aiojobs.aiohttp.shield` can work well. 
The only disadvantage is you need to split the web handler into two async functions: one for the handler itself and another for protected code. .. warning:: We don't recommend using :func:`asyncio.shield` for this because the shielded task cannot be tracked by the application and therefore there is a risk that the task will get cancelled during application shutdown. The function provided by aiojobs_ operates in the same way except the inner task will be tracked by the Scheduler and will get waited on during the cleanup phase. For example the following snippet is not safe:: from aiojobs.aiohttp import shield async def handler(request): await shield(request, write_to_redis(request)) await shield(request, write_to_postgres(request)) return web.Response(text="OK") Cancellation might occur while saving data in REDIS, so the ``write_to_postgres`` function will not be called, potentially leaving your data in an inconsistent state. Instead, you would need to write something like:: async def write_data(request): await write_to_redis(request) await write_to_postgres(request) async def handler(request): await shield(request, write_data(request)) return web.Response(text="OK") Alternatively, if you want to spawn a task without waiting for its completion, you can use aiojobs_ which provides an API for spawning new background jobs. It stores all scheduled activity in internal data structures and can terminate them gracefully:: from aiojobs.aiohttp import setup, spawn async def handler(request): await spawn(request, write_data()) return web.Response() app = web.Application() setup(app) app.router.add_get("/", handler) .. warning:: Don't use :func:`asyncio.create_task` for this. All tasks should be awaited at some point in your code (``aiojobs`` handles this for you), otherwise you will hide legitimate exceptions and result in warnings being emitted. 
A good case for using :func:`asyncio.create_task` is when you want to run something while you are processing other data, but still want to ensure the task is complete before returning:: async def handler(request): t = asyncio.create_task(get_some_data()) ... # Do some other things, while data is being fetched. data = await t return web.Response(text=data) One more approach would be to use :func:`aiojobs.aiohttp.atomic` decorator to execute the entire handler as a new job. Essentially restoring the default disconnection behavior only for specific handlers:: from aiojobs.aiohttp import atomic @atomic async def handler(request): await write_to_db() return web.Response() app = web.Application() setup(app) app.router.add_post("/", handler) It prevents all of the ``handler`` async function from cancellation, so ``write_to_db`` will never be interrupted. .. _aiojobs: http://aiojobs.readthedocs.io/en/latest/ Passing a coroutine into run_app and Gunicorn --------------------------------------------- :func:`run_app` accepts either application instance or a coroutine for making an application. The coroutine based approach allows to perform async IO before making an app:: async def app_factory(): await pre_init() app = web.Application() app.router.add_get(...) return app web.run_app(app_factory()) Gunicorn worker supports a factory as well. For Gunicorn the factory should accept zero parameters:: async def my_web_app(): app = web.Application() app.router.add_get(...) return app Start gunicorn: .. code-block:: shell $ gunicorn my_app_module:my_web_app --bind localhost:8080 --worker-class aiohttp.GunicornWebWorker .. versionadded:: 3.1 Custom Routing Criteria ----------------------- Sometimes you need to register :ref:`handlers ` on more complex criteria than simply a *HTTP method* and *path* pair. 
Although :class:`UrlDispatcher` does not support any extra criteria, routing based on custom conditions can be accomplished by implementing a second layer of routing in your application. The following example shows custom routing based on the *HTTP Accept* header:: class AcceptChooser: def __init__(self): self._accepts = {} async def do_route(self, request): for accept in request.headers.getall('ACCEPT', []): acceptor = self._accepts.get(accept) if acceptor is not None: return (await acceptor(request)) raise HTTPNotAcceptable() def reg_acceptor(self, accept, handler): self._accepts[accept] = handler async def handle_json(request): # do json handling async def handle_xml(request): # do xml handling chooser = AcceptChooser() app.add_routes([web.get('/', chooser.do_route)]) chooser.reg_acceptor('application/json', handle_json) chooser.reg_acceptor('application/xml', handle_xml) .. _aiohttp-web-static-file-handling: Static file handling -------------------- The best way to handle static files (images, JavaScripts, CSS files etc.) is using `Reverse Proxy`_ like `nginx`_ or `CDN`_ services. .. _Reverse Proxy: https://en.wikipedia.org/wiki/Reverse_proxy .. _nginx: https://nginx.org/ .. _CDN: https://en.wikipedia.org/wiki/Content_delivery_network But for development it's very convenient to handle static files by aiohttp server itself. To do it just register a new static route by :meth:`RouteTableDef.static` or :func:`static` calls:: app.add_routes([web.static('/prefix', path_to_static_folder)]) routes.static('/prefix', path_to_static_folder) When a directory is accessed within a static route then the server responds to the client with ``HTTP/403 Forbidden`` by default. Displaying folder index instead could be enabled with ``show_index`` parameter set to ``True``:: web.static('/prefix', path_to_static_folder, show_index=True) When a symlink that leads outside the static directory is accessed, the server responds to the client with ``HTTP/404 Not Found`` by default.
To allow the server to follow symlinks that lead outside the static root, the parameter ``follow_symlinks`` should be set to ``True``:: web.static('/prefix', path_to_static_folder, follow_symlinks=True) .. caution:: Enabling ``follow_symlinks`` can be a security risk, and may lead to a directory traversal attack. You do NOT need this option to follow symlinks which point to somewhere else within the static directory, this option is only used to break out of the security sandbox. Enabling this option is highly discouraged, and only expected to be used for edge cases in a local development setting where remote users do not have access to the server. When you want to enable cache busting, parameter ``append_version`` can be set to ``True``. Cache busting is the process of appending some form of file version hash to the filename of resources like JavaScript and CSS files. The performance advantage of doing this is that we can tell the browser to cache these files indefinitely without worrying about the client not getting the latest version when the file changes:: web.static('/prefix', path_to_static_folder, append_version=True) Template Rendering ------------------ :mod:`aiohttp.web` does not support template rendering out-of-the-box. However, there is a third-party library, :mod:`aiohttp_jinja2`, which is supported by the *aiohttp* authors. Using it is rather simple. First, setup a *jinja2 environment* with a call to :func:`aiohttp_jinja2.setup`:: app = web.Application() aiohttp_jinja2.setup(app, loader=jinja2.FileSystemLoader('/path/to/templates/folder')) After that you may use the template engine in your :ref:`handlers `. The most convenient way is to simply wrap your handlers with the :func:`aiohttp_jinja2.template` decorator:: @aiohttp_jinja2.template('tmpl.jinja2') async def handler(request): return {'name': 'Andrew', 'surname': 'Svetlov'} If you prefer the `Mako`_ template engine, please take a look at the `aiohttp_mako`_ library. ..
warning:: :func:`aiohttp_jinja2.template` should be applied **before** :meth:`RouteTableDef.get` decorator and family, e.g. it must be the *first* (most *down* decorator in the chain):: @routes.get('/path') @aiohttp_jinja2.template('tmpl.jinja2') async def handler(request): return {'name': 'Andrew', 'surname': 'Svetlov'} .. _Mako: http://www.makotemplates.org/ .. _aiohttp_mako: https://github.com/aio-libs/aiohttp_mako .. _aiohttp-web-websocket-read-same-task: Reading from the same task in WebSockets ---------------------------------------- Reading from the *WebSocket* (``await ws.receive()``) **must only** be done inside the request handler *task*; however, writing (``ws.send_str(...)``) to the *WebSocket*, closing (``await ws.close()``) and canceling the handler task may be delegated to other tasks. See also :ref:`FAQ section `. :mod:`aiohttp.web` creates an implicit :class:`asyncio.Task` for handling every incoming request. .. note:: While :mod:`aiohttp.web` itself only supports *WebSockets* without downgrading to *LONG-POLLING*, etc., our team supports SockJS_, an aiohttp-based library for implementing SockJS-compatible server code. .. _SockJS: https://github.com/aio-libs/sockjs .. warning:: Parallel reads from websocket are forbidden, there is no possibility to call :meth:`WebSocketResponse.receive` from two tasks. See :ref:`FAQ section ` for instructions how to solve the problem. .. _aiohttp-web-data-sharing: Data Sharing aka No Singletons Please ------------------------------------- :mod:`aiohttp.web` discourages the use of *global variables*, aka *singletons*. Every variable should have its own context that is *not global*. Global variables are generally considered bad practice due to the complexity they add in keeping track of state changes to variables. *aiohttp* does not use globals by design, which will reduce the number of bugs and/or unexpected behaviors for its users. 
For example, an i18n translated string being written for one request and then being served to another. So, :class:`Application` and :class:`Request` support a :class:`collections.abc.MutableMapping` interface (i.e. they are dict-like objects), allowing them to be used as data stores. .. _aiohttp-web-data-sharing-app-config: Application's config ^^^^^^^^^^^^^^^^^^^^ For storing *global-like* variables, feel free to save them in an :class:`Application` instance:: app['my_private_key'] = data and get it back in the :term:`web-handler`:: async def handler(request): data = request.app['my_private_key'] Rather than using :class:`str` keys, we recommend using :class:`AppKey`. This is required for type safety (e.g. when checking with mypy):: my_private_key = web.AppKey("my_private_key", str) app[my_private_key] = data async def handler(request: web.Request): data = request.app[my_private_key] # reveal_type(data) -> str In case of :ref:`nested applications ` the desired lookup strategy could be the following: 1. Search the key in the current nested application. 2. If the key is not found continue searching in the parent application(s). For this please use :attr:`Request.config_dict` read-only property:: async def handler(request): data = request.config_dict[my_private_key] The app object can be used in this way to reuse a database connection or anything else needed throughout the application. See this reference section for more detail: :ref:`aiohttp-web-app-and-router`. Request's storage ^^^^^^^^^^^^^^^^^ Variables that are only needed for the lifetime of a :class:`Request`, can be stored in a :class:`Request`. Similarly to :class:`Application`, :class:`RequestKey` instances or strings can be used as keys:: my_private_key = web.RequestKey("my_private_key", str) async def handler(request): request[my_private_key] = "data" ... 
This is mostly useful for :ref:`aiohttp-web-middlewares` and :ref:`aiohttp-web-signals` handlers to store data for further processing by the next handlers in the chain. Response's storage ^^^^^^^^^^^^^^^^^^ :class:`StreamResponse` and :class:`Response` objects also support :class:`collections.abc.MutableMapping` interface. This is useful when you want to share data with signals and middlewares once all the work in the handler is done:: my_metric_key = web.ResponseKey("my_metric_key", int) async def handler(request): [ do all the work ] response[my_metric_key] = 123 return response Naming hint ^^^^^^^^^^^ To avoid clashing with other *aiohttp* users and third-party libraries, please choose a unique key name for storing data. If your code is published on PyPI, then the project name is most likely unique and safe to use as the key. Otherwise, something based on your company name/url would be satisfactory (i.e. ``org.company.app``). .. _aiohttp-web-contextvars: ContextVars support ------------------- Asyncio has :mod:`Context Variables ` as a context-local storage (a generalization of thread-local concept that works with asyncio tasks also). *aiohttp* server supports it in the following way: * A server inherits the current task's context used when creating it. :func:`aiohttp.web.run_app()` runs a task for handling all underlying jobs running the app, but alternatively :ref:`aiohttp-web-app-runners` can be used. * Application initialization / finalization events (:attr:`Application.cleanup_ctx`, :attr:`Application.on_startup` and :attr:`Application.on_shutdown`, :attr:`Application.on_cleanup`) are executed inside the same context. E.g. all context modifications made on application startup are visible on teardown. * On every request handling *aiohttp* creates a context copy. :term:`web-handler` has all variables installed on initialization stage. But the context modification made by a handler or middleware is invisible to another HTTP request handling call. 
An example of context vars usage:: from contextvars import ContextVar from aiohttp import web VAR = ContextVar('VAR', default='default') async def coro(): return VAR.get() async def handler(request): var = VAR.get() VAR.set('handler') ret = await coro() return web.Response(text='\n'.join([var, ret])) async def on_startup(app): print('on_startup', VAR.get()) VAR.set('on_startup') async def on_cleanup(app): print('on_cleanup', VAR.get()) VAR.set('on_cleanup') async def init(): print('init', VAR.get()) VAR.set('init') app = web.Application() app.router.add_get('/', handler) app.on_startup.append(on_startup) app.on_cleanup.append(on_cleanup) return app web.run_app(init()) print('done', VAR.get()) .. versionadded:: 3.5 .. _aiohttp-web-middlewares: Middlewares ----------- :mod:`aiohttp.web` provides a powerful mechanism for customizing :ref:`request handlers` via *middlewares*. A *middleware* is a coroutine that can modify either the request or response. For example, here's a simple *middleware* which appends ``' wink'`` to the response:: from aiohttp import web from typing import Callable, Awaitable async def middleware( request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]] ) -> web.StreamResponse: resp = await handler(request) resp.text = resp.text + ' wink' return resp .. warning:: As of version ``4.0.0`` "new-style" middleware is default and the ``@middleware`` decorator is not required (and is deprecated), you can simply remove the decorator. "Old-style" middleware (a coroutine which returned a coroutine) is no longer supported. .. note:: The example won't work with streamed responses or websockets Every *middleware* should accept two parameters, a :class:`request ` instance and a *handler*, and return the response or raise an exception. If the exception is not an instance of :exc:`HTTPException` it is converted to ``500`` :exc:`HTTPInternalServerError` after processing the middlewares chain. .. 
warning:: Second argument should be named *handler* exactly. When creating an :class:`Application`, these *middlewares* are passed to the keyword-only ``middlewares`` parameter:: app = web.Application(middlewares=[middleware_1, middleware_2]) Internally, a single :ref:`request handler ` is constructed by applying the middleware chain to the original handler in reverse order, and is called by the :class:`~aiohttp.web.RequestHandler` as a regular *handler*. Since *middlewares* are themselves coroutines, they may perform extra ``await`` calls when creating a new handler, e.g. call database etc. *Middlewares* usually call the handler, but they may choose to ignore it, e.g. displaying *403 Forbidden page* or raising :exc:`HTTPForbidden` exception if the user does not have permissions to access the underlying resource. They may also render errors raised by the handler, perform some pre- or post-processing like handling *CORS* and so on. The following code demonstrates middlewares execution order:: from aiohttp import web from typing import Callable, Awaitable async def test(request: web.Request) -> web.Response: print('Handler function called') return web.Response(text="Hello") async def middleware1( request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]] ) -> web.StreamResponse: print('Middleware 1 called') response = await handler(request) print('Middleware 1 finished') return response async def middleware2( request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]] ) -> web.StreamResponse: print('Middleware 2 called') response = await handler(request) print('Middleware 2 finished') return response app = web.Application(middlewares=[middleware1, middleware2]) app.router.add_get('/', test) web.run_app(app) Produced output:: Middleware 1 called Middleware 2 called Handler function called Middleware 2 finished Middleware 1 finished Request Body Stream Consumption ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. 
warning:: When middleware reads the request body (using :meth:`~aiohttp.web.BaseRequest.read`, :meth:`~aiohttp.web.BaseRequest.text`, :meth:`~aiohttp.web.BaseRequest.json`, or :meth:`~aiohttp.web.BaseRequest.post`), the body stream is consumed. However, these high-level methods cache their result, so subsequent calls from the handler or other middleware will return the same cached value. The important distinction is: - High-level methods (:meth:`~aiohttp.web.BaseRequest.read`, :meth:`~aiohttp.web.BaseRequest.text`, :meth:`~aiohttp.web.BaseRequest.json`, :meth:`~aiohttp.web.BaseRequest.post`) cache their results internally, so they can be called multiple times and will return the same value. - Direct stream access via :attr:`~aiohttp.web.BaseRequest.content` does NOT have this caching behavior. Once you read from ``request.content`` directly (e.g., using ``await request.content.read()``), subsequent reads will return empty bytes. Consider this middleware that logs request bodies:: from aiohttp import web from typing import Callable, Awaitable async def logging_middleware( request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]] ) -> web.StreamResponse: # This consumes the request body stream body = await request.text() print(f"Request body: {body}") return await handler(request) async def handler(request: web.Request) -> web.Response: # This will return the same value that was read in the middleware # (i.e., the cached result, not an empty string) body = await request.text() return web.Response(text=f"Received: {body}") In contrast, when accessing the stream directly (not recommended in middleware):: async def stream_middleware( request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]] ) -> web.StreamResponse: # Reading directly from the stream - this consumes it! 
data = await request.content.read() print(f"Stream data: {data}") return await handler(request) async def handler(request: web.Request) -> web.Response: # This will return empty bytes because the stream was already consumed data = await request.content.read() # data will be b'' (empty bytes) # However, high-level methods would still work if called for the first time: # body = await request.text() # This would read from internal cache if available return web.Response(text=f"Received: {data}") When working with raw stream data that needs to be shared between middleware and handlers:: raw_body_key = web.RequestKey("raw_body_key", bytes) async def stream_parsing_middleware( request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]] ) -> web.StreamResponse: # Read stream once and store the data raw_data = await request.content.read() request[raw_body_key] = raw_data return await handler(request) async def handler(request: web.Request) -> web.Response: # Access the stored data instead of reading the stream again raw_data = request.get(raw_body_key, b'') return web.Response(body=raw_data) Example ^^^^^^^ A common use of middlewares is to implement custom error pages. The following example will render 404 errors using a JSON response, as might be appropriate a JSON REST service:: from aiohttp import web async def error_middleware(request, handler): try: response = await handler(request) if response.status != 404: return response message = response.message except web.HTTPException as ex: if ex.status != 404: raise message = ex.reason return web.json_response({'error': message}) app = web.Application(middlewares=[error_middleware]) Middleware Factory ^^^^^^^^^^^^^^^^^^ A *middleware factory* is a function that creates a middleware with passed arguments. 
For example, here's a trivial *middleware factory*:: def middleware_factory(text): async def sample_middleware(request, handler): resp = await handler(request) resp.text = resp.text + text return resp return sample_middleware Note that in contrast to regular middlewares, a middleware factory should return the function, not the value. So when passing a middleware factory to the app you actually need to call it:: app = web.Application(middlewares=[middleware_factory(' wink')]) .. _aiohttp-web-signals: Signals ------- Although :ref:`middlewares ` can customize :ref:`request handlers` before or after a :class:`Response` has been prepared, they can't customize a :class:`Response` **while** it's being prepared. For this :mod:`aiohttp.web` provides *signals*. For example, a middleware can only change HTTP headers for *unprepared* responses (see :meth:`StreamResponse.prepare`), but sometimes we need a hook for changing HTTP headers for streamed responses and WebSockets. This can be accomplished by subscribing to the :attr:`Application.on_response_prepare` signal, which is called after default headers have been computed and directly before headers are sent:: async def on_prepare(request, response): response.headers['My-Header'] = 'value' app.on_response_prepare.append(on_prepare) Additionally, the :attr:`Application.on_startup` and :attr:`Application.on_cleanup` signals can be subscribed to for application component setup and tear down accordingly. 
The following example will properly initialize and dispose of an asyncpg connection engine:: from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine pg_engine = web.AppKey("pg_engine", AsyncEngine) async def create_pg(app): app[pg_engine] = await create_async_engine( "postgresql+asyncpg://postgre:@localhost:5432/postgre" ) async def dispose_pg(app): await app[pg_engine].dispose() app.on_startup.append(create_pg) app.on_cleanup.append(dispose_pg) Signal handlers should not return a value but may modify incoming mutable parameters. Signal handlers will be run sequentially, in the order they were added. All handlers must be asynchronous since *aiohttp* 3.0. .. _aiohttp-web-cleanup-ctx: Cleanup Context --------------- Bare :attr:`Application.on_startup` / :attr:`Application.on_cleanup` pair still has a pitfall: signal handlers are independent of each other. E.g. we have ``[create_pg, create_redis]`` in *startup* signal and ``[dispose_pg, dispose_redis]`` in *cleanup*. If, for example, ``create_pg(app)`` call fails ``create_redis(app)`` is not called. But on application cleanup both ``dispose_pg(app)`` and ``dispose_redis(app)`` are still called: *cleanup signal* has no knowledge about startup/cleanup pairs and their execution state. The solution is :attr:`Application.cleanup_ctx` usage:: @contextlib.asynccontextmanager async def pg_engine_ctx(app: web.Application): app[pg_engine] = await create_async_engine( "postgresql+asyncpg://postgre:@localhost:5432/postgre" ) yield await app[pg_engine].dispose() app.cleanup_ctx.append(pg_engine_ctx) The attribute is a list of *asynchronous generators*, a code *before* ``yield`` is an initialization stage (called on *startup*), a code *after* ``yield`` is executed on *cleanup*. The generator must have only one ``yield``. *aiohttp* guarantees that *cleanup code* is called if and only if *startup code* was successfully finished. .. versionadded:: 3.1 ..
_aiohttp-web-nested-applications: Nested applications ------------------- Sub applications are designed for solving the problem of the big monolithic code base. Let's assume we have a project with its own business logic and tools like administration panel and debug toolbar. Administration panel is a separate application by its own nature but all toolbar URLs are served by prefix like ``/admin``. Thus we'll create a totally separate application named ``admin`` and connect it to main app with prefix by :meth:`Application.add_subapp`:: admin = web.Application() # setup admin routes, signals and middlewares app.add_subapp('/admin/', admin) Middlewares and signals from ``app`` and ``admin`` are chained. It means that if URL is ``'/admin/something'`` middlewares from ``app`` are applied first and ``admin.middlewares`` are the next in the call chain. The same is going for :attr:`Application.on_response_prepare` signal -- the signal is delivered to both top level ``app`` and ``admin`` if processing URL is routed to ``admin`` sub-application. Common signals like :attr:`Application.on_startup`, :attr:`Application.on_shutdown` and :attr:`Application.on_cleanup` are delivered to all registered sub-applications. The passed parameter is sub-application instance, not top-level application. Third level sub-applications can be nested into second level ones -- there is no limit on the nesting level. Url reversing for sub-applications should generate urls with proper prefix. But for getting URL sub-application's router should be used:: admin = web.Application() admin.add_routes([web.get('/resource', handler, name='name')]) app.add_subapp('/admin/', admin) url = admin.router['name'].url_for() The generated ``url`` from example will have a value ``URL('/admin/resource')``.
If main application should do URL reversing for sub-application it could use the following explicit technique:: admin = web.Application() admin_key = web.AppKey('admin_key', web.Application) admin.add_routes([web.get('/resource', handler, name='name')]) app.add_subapp('/admin/', admin) app[admin_key] = admin async def handler(request: web.Request): # main application's handler admin = request.app[admin_key] url = admin.router['name'].url_for() .. _aiohttp-web-expect-header: *Expect* Header --------------- :mod:`aiohttp.web` supports *Expect* header. By default it sends ``HTTP/1.1 100 Continue`` line to client, or raises :exc:`HTTPExpectationFailed` if header value is not equal to "100-continue". It is possible to specify custom *Expect* header handler on per route basis. This handler gets called if *Expect* header exist in request after receiving all headers and before processing application's :ref:`aiohttp-web-middlewares` and route handler. Handler can return *None*, in that case the request processing continues as usual. If handler returns an instance of class :class:`StreamResponse`, *request handler* uses it as response. Also handler can raise a subclass of :exc:`HTTPException`. In this case all further processing will not happen and client will receive appropriate http response. .. note:: A server that does not understand or is unable to comply with any of the expectation values in the Expect field of a request MUST respond with appropriate error status. The server MUST respond with a 417 (Expectation Failed) status if any of the expectations cannot be met or, if there are other problems with the request, some other 4xx status. http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.20 If all checks pass, the custom handler *must* write a *HTTP/1.1 100 Continue* status code before returning. 
The following example shows how to setup a custom handler for the *Expect* header:: async def check_auth(request): if request.version != aiohttp.HttpVersion11: return if request.headers.get('EXPECT') != '100-continue': raise HTTPExpectationFailed(text="Unknown Expect: %s" % request.headers.get('EXPECT')) if request.headers.get('AUTHORIZATION') is None: raise HTTPForbidden() request.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") async def hello(request): return web.Response(body=b"Hello, world") app = web.Application() app.add_routes([web.get('/', hello, expect_handler=check_auth)]) .. _aiohttp-web-custom-resource: Custom resource implementation ------------------------------ To register custom resource use :meth:`~aiohttp.web.UrlDispatcher.register_resource`. Resource instance must implement `AbstractResource` interface. .. _aiohttp-web-app-runners: Application runners ------------------- :func:`run_app` provides a simple *blocking* API for running an :class:`Application`. For starting the application *asynchronously* or serving on multiple HOST/PORT :class:`AppRunner` exists. The simple startup code for serving HTTP site on ``'localhost'``, port ``8080`` looks like:: runner = web.AppRunner(app) await runner.setup() site = web.TCPSite(runner, 'localhost', 8080) await site.start() while True: await asyncio.sleep(3600) # sleep forever To stop serving call :meth:`AppRunner.cleanup`:: await runner.cleanup() .. versionadded:: 3.0 .. _aiohttp-web-graceful-shutdown: Graceful shutdown ------------------ Stopping *aiohttp web server* by just closing all connections is not always satisfactory. When aiohttp is run with :func:`run_app`, it will attempt a graceful shutdown by following these steps (if using a :ref:`runner `, then calling :meth:`AppRunner.cleanup` will perform these steps, excluding step 7). 1. Stop each site listening on sockets, so new connections will be rejected. 2. Close idle keep-alive connections (and set active ones to close upon completion). 3.
Call the :attr:`Application.on_shutdown` signal. This should be used to shutdown long-lived connections, such as websockets (see below). 4. Wait a short time for running handlers to complete. This allows any pending handlers to complete successfully. The timeout can be adjusted with ``shutdown_timeout`` in :func:`run_app`. 5. Close any remaining connections and cancel their handlers. It will wait on the canceling handlers for a short time, again adjustable with ``shutdown_timeout``. 6. Call the :attr:`Application.on_cleanup` signal. This should be used to cleanup any resources (such as DB connections). This includes completing the :ref:`cleanup contexts` which may be used to ensure background tasks are completed successfully (see :ref:`handler cancellation` or aiojobs_ for examples). 7. Cancel any remaining tasks and wait on them to complete. Websocket shutdown ^^^^^^^^^^^^^^^^^^ One problem is if the application supports :term:`websockets ` or *data streaming* it most likely has open connections at server shutdown time. The *library* has no knowledge how to close them gracefully but a developer can help by registering an :attr:`Application.on_shutdown` signal handler. A developer should keep a list of opened connections (:class:`Application` is a good candidate). The following :term:`websocket` snippet shows an example of a websocket handler:: from aiohttp import web import weakref app = web.Application() websockets = web.AppKey("websockets", weakref.WeakSet) app[websockets] = weakref.WeakSet() async def websocket_handler(request): ws = web.WebSocketResponse() await ws.prepare(request) request.app[websockets].add(ws) try: async for msg in ws: ... finally: request.app[websockets].discard(ws) return ws Then the signal handler may look like:: from aiohttp import WSCloseCode async def on_shutdown(app): for ws in set(app[websockets]): await ws.close(code=WSCloseCode.GOING_AWAY, message="Server shutdown") app.on_shutdown.append(on_shutdown) .. 
_aiohttp-web-ceil-absolute-timeout: Ceil of absolute timeout value ------------------------------ *aiohttp* **ceils** internal timeout values if the value is equal or greater than 5 seconds. The timeout expires at the next integer second greater than ``current_time + timeout``. More details about ceiling absolute timeout values is available here :ref:`aiohttp-client-timeouts`. The default threshold can be configured at :class:`aiohttp.web.Application` level using the ``handler_args`` parameter. .. code-block:: python3 app = web.Application(handler_args={"timeout_ceil_threshold": 1}) .. _aiohttp-web-background-tasks: Background tasks ----------------- Sometimes there's a need to perform some asynchronous operations just after application start-up. Even more, in some sophisticated systems there could be a need to run some background tasks in the event loop along with the application's request handler. Such as listening to message queue or other network message/event sources (e.g. ZeroMQ, Redis Pub/Sub, AMQP, etc.) to react to received messages within the application. For example the background task could listen to ZeroMQ on ``zmq.SUB`` socket, process and forward retrieved messages to clients connected via WebSocket that are stored somewhere in the application (e.g. in the ``application['websockets']`` list). To run such short and long running background tasks aiohttp provides an ability to register :attr:`Application.on_startup` signal handler(s) that will run along with the application's request handler. For example there's a need to run one quick task and two long running tasks that will live till the application is alive. 
The appropriate background tasks could be registered as an :attr:`Application.on_startup` signal handler or :attr:`Application.cleanup_ctx` as shown in the example below:: async def listen_to_redis(app: web.Application): client = redis.from_url("redis://localhost:6379") channel = "news" async with client.pubsub() as pubsub: await pubsub.subscribe(channel) while True: msg = await pubsub.get_message(ignore_subscribe_messages=True) if msg is not None: for ws in app["websockets"]: await ws.send_str("{}: {}".format(channel, msg)) @contextlib.asynccontextmanager async def background_tasks(app): app[redis_listener] = asyncio.create_task(listen_to_redis(app)) yield app[redis_listener].cancel() with contextlib.suppress(asyncio.CancelledError): await app[redis_listener] app = web.Application() redis_listener = web.AppKey("redis_listener", asyncio.Task[None]) app.cleanup_ctx.append(background_tasks) web.run_app(app) The task ``listen_to_redis`` will run forever. To shut it down correctly, an :attr:`Application.on_cleanup` signal handler may be used to send a cancellation to it. .. _aiohttp-web-complex-applications: Complex Applications ^^^^^^^^^^^^^^^^^^^^ Sometimes aiohttp is not the sole part of an application and additional tasks/processes may need to be run alongside the aiohttp :class:`Application`. Generally, the best way to achieve this is to use :func:`aiohttp.web.run_app` as the entry point for the program. Other tasks can then be run via :attr:`Application.on_startup` and :attr:`Application.on_cleanup`. By having the :class:`Application` control the lifecycle of the entire program, the code will be more robust and ensure that the tasks are started and stopped along with the application.
For example, running a long-lived task alongside the :class:`Application` can be done with a :ref:`aiohttp-web-cleanup-ctx` function like:: @contextlib.asynccontextmanager async def run_other_task(_app): task = asyncio.create_task(other_long_task()) yield task.cancel() with suppress(asyncio.CancelledError): await task # Ensure any exceptions etc. are raised. app.cleanup_ctx.append(run_other_task) Or a separate process can be run with something like:: @contextlib.asynccontextmanager async def run_process(_app): proc = await asyncio.create_subprocess_exec(path) yield if proc.returncode is None: proc.terminate() await proc.wait() app.cleanup_ctx.append(run_process) Handling error pages -------------------- Pages like *404 Not Found* and *500 Internal Error* could be handled by custom middleware, see :ref:`polls demo ` for example. .. _aiohttp-web-forwarded-support: Deploying behind a Proxy ------------------------ As discussed in :ref:`aiohttp-deployment` the preferable way is deploying *aiohttp* web server behind a *Reverse Proxy Server* like :term:`nginx` for production usage. In this way properties like :attr:`BaseRequest.scheme` :attr:`BaseRequest.host` and :attr:`BaseRequest.remote` are incorrect. Real values should be given from proxy server, usually either ``Forwarded`` or old-fashion ``X-Forwarded-For``, ``X-Forwarded-Host``, ``X-Forwarded-Proto`` HTTP headers are used. *aiohttp* does not take *forwarded* headers into account by default because it produces *security issue*: HTTP client might add these headers too, pushing non-trusted data values. That's why *aiohttp server* should setup *forwarded* headers in custom middleware in tight conjunction with *reverse proxy configuration*. For changing :attr:`BaseRequest.scheme` :attr:`BaseRequest.host` :attr:`BaseRequest.remote` and :attr:`BaseRequest.client_max_size` the middleware might use :meth:`BaseRequest.clone`. .. 
seealso:: https://github.com/aio-libs/aiohttp-remotes provides secure helpers for modifying *scheme*, *host* and *remote* attributes according to ``Forwarded`` and ``X-Forwarded-*`` HTTP headers. CORS support ------------ :mod:`aiohttp.web` itself does not support `Cross-Origin Resource Sharing `_, but there is an aiohttp plugin for it: `aiohttp-cors `_. Debug Toolbar ------------- `aiohttp-debugtoolbar`_ is a very useful library that provides a debugging toolbar while you're developing an :mod:`aiohttp.web` application. Install it with ``pip``: .. code-block:: shell $ pip install aiohttp_debugtoolbar Just call :func:`aiohttp_debugtoolbar.setup`:: import aiohttp_debugtoolbar from aiohttp_debugtoolbar import toolbar_middleware_factory app = web.Application() aiohttp_debugtoolbar.setup(app) The toolbar is ready to use. Enjoy!!! .. _aiohttp-debugtoolbar: https://github.com/aio-libs/aiohttp_debugtoolbar Dev Tools --------- `aiohttp-devtools`_ provides a couple of tools to simplify development of :mod:`aiohttp.web` applications. Install with ``pip``: .. code-block:: shell $ pip install aiohttp-devtools ``adev runserver`` provides a development server with auto-reload, live-reload, static file serving. Documentation and a complete tutorial of creating and running an app locally are available at `aiohttp-devtools`_. .. _aiohttp-devtools: https://github.com/aio-libs/aiohttp-devtools ================================================ FILE: docs/web_exceptions.rst ================================================ .. currentmodule:: aiohttp.web .. _aiohttp-web-exceptions: Web Server Exceptions ===================== Overview -------- :mod:`aiohttp.web` defines a set of exceptions for every *HTTP status code*. 
Each exception is a subclass of :exc:`HTTPException` and relates to a single HTTP status code:: async def handler(request): raise aiohttp.web.HTTPFound('/redirect') Each exception class has a status code according to :rfc:`2068`: codes with 100-300 are not really errors; 400s are client errors, and 500s are server errors. HTTP Exception hierarchy chart:: Exception HTTPException HTTPSuccessful * 200 - HTTPOk * 201 - HTTPCreated * 202 - HTTPAccepted * 203 - HTTPNonAuthoritativeInformation * 204 - HTTPNoContent * 205 - HTTPResetContent * 206 - HTTPPartialContent HTTPRedirection * 304 - HTTPNotModified HTTPMove * 300 - HTTPMultipleChoices * 301 - HTTPMovedPermanently * 302 - HTTPFound * 303 - HTTPSeeOther * 305 - HTTPUseProxy * 307 - HTTPTemporaryRedirect * 308 - HTTPPermanentRedirect HTTPError HTTPClientError * 400 - HTTPBadRequest * 401 - HTTPUnauthorized * 402 - HTTPPaymentRequired * 403 - HTTPForbidden * 404 - HTTPNotFound * 405 - HTTPMethodNotAllowed * 406 - HTTPNotAcceptable * 407 - HTTPProxyAuthenticationRequired * 408 - HTTPRequestTimeout * 409 - HTTPConflict * 410 - HTTPGone * 411 - HTTPLengthRequired * 412 - HTTPPreconditionFailed * 413 - HTTPRequestEntityTooLarge * 414 - HTTPRequestURITooLong * 415 - HTTPUnsupportedMediaType * 416 - HTTPRequestRangeNotSatisfiable * 417 - HTTPExpectationFailed * 421 - HTTPMisdirectedRequest * 422 - HTTPUnprocessableEntity * 424 - HTTPFailedDependency * 426 - HTTPUpgradeRequired * 428 - HTTPPreconditionRequired * 429 - HTTPTooManyRequests * 431 - HTTPRequestHeaderFieldsTooLarge * 451 - HTTPUnavailableForLegalReasons HTTPServerError * 500 - HTTPInternalServerError * 501 - HTTPNotImplemented * 502 - HTTPBadGateway * 503 - HTTPServiceUnavailable * 504 - HTTPGatewayTimeout * 505 - HTTPVersionNotSupported * 506 - HTTPVariantAlsoNegotiates * 507 - HTTPInsufficientStorage * 510 - HTTPNotExtended * 511 - HTTPNetworkAuthenticationRequired All HTTP exceptions have the same constructor signature:: HTTPNotFound(*, headers=None, 
reason=None, text=None, content_type=None) If not directly specified, *headers* will be added to the *default response headers*. Classes :exc:`HTTPMultipleChoices`, :exc:`HTTPMovedPermanently`, :exc:`HTTPFound`, :exc:`HTTPSeeOther`, :exc:`HTTPUseProxy`, :exc:`HTTPTemporaryRedirect` and :exc:`HTTPPermanentRedirect` have the following constructor signature:: HTTPFound(location, *,headers=None, reason=None, text=None, content_type=None) where *location* is the value for the *Location HTTP header*. :exc:`HTTPMethodNotAllowed` is constructed by providing the incoming unsupported method and a list of allowed methods:: HTTPMethodNotAllowed(method, allowed_methods, *, headers=None, reason=None, text=None, content_type=None) :exc:`HTTPUnavailableForLegalReasons` should be constructed with a ``link`` to yourself (as the entity implementing the blockage), and an explanation for the block included in ``text``:: HTTPUnavailableForLegalReasons(link, *, headers=None, reason=None, text=None, content_type=None) Base HTTP Exception ------------------- .. exception:: HTTPException(*, headers=None, reason=None, text=None, \ content_type=None) :canonical: aiohttp.web_exceptions.HTTPException The base class for HTTP server exceptions. Inherited from :exc:`Exception`. :param headers: HTTP headers (:class:`~collections.abc.Mapping`) :param str reason: an optional custom HTTP reason. aiohttp uses *default reason string* if not specified. :param str text: an optional text used in response body. If not specified *default text* is constructed from status code and reason, e.g. `"404: Not Found"`. :param str content_type: an optional Content-Type, `"text/plain"` by default. .. attribute:: status HTTP status code for the exception, :class:`int` .. attribute:: reason HTTP status reason for the exception, :class:`str` .. attribute:: text Text used in the response body for the exception, :class:`str` or ``None`` for HTTP exceptions without body, e.g. "204 No Content" ..
attribute:: headers HTTP headers for the exception, :class:`multidict.CIMultiDict` .. attribute:: cookies An instance of :class:`http.cookies.SimpleCookie` for *outgoing* cookies. .. versionadded:: 4.0 .. method:: set_cookie(name, value, *, path='/', expires=None, \ domain=None, max_age=None, \ secure=None, httponly=None, version=None, \ samesite=None) Convenient way for setting :attr:`cookies`, allows to specify some additional properties like *max_age* in a single call. .. versionadded:: 4.0 :param str name: cookie name :param str value: cookie value (will be converted to :class:`str` if value has another type). :param expires: expiration date (optional) :param str domain: cookie domain (optional) :param int max_age: defines the lifetime of the cookie, in seconds. The delta-seconds value is a decimal non- negative integer. After delta-seconds seconds elapse, the client should discard the cookie. A value of zero means the cookie should be discarded immediately. (optional) :param str path: specifies the subset of URLs to which this cookie applies. (optional, ``'/'`` by default) :param bool secure: attribute (with no value) directs the user agent to use only (unspecified) secure means to contact the origin server whenever it sends back this cookie. The user agent (possibly under the user's control) may determine what level of security it considers appropriate for "secure" cookies. The *secure* should be considered security advice from the server to the user agent, indicating that it is in the session's interest to protect the cookie contents. (optional) :param bool httponly: ``True`` if the cookie HTTP only (optional) :param int version: a decimal integer, identifies to which version of the state management specification the cookie conforms. (Optional, *version=1* by default) :param str samesite: Asserts that a cookie must not be sent with cross-origin requests, providing some protection against cross-site request forgery attacks. 
Generally the value should be one of: ``None``, ``Lax`` or ``Strict``. (optional) .. warning:: In HTTP version 1.1, ``expires`` was deprecated and replaced with the easier-to-use ``max-age``, but Internet Explorer (IE6, IE7, and IE8) **does not** support ``max-age``. .. method:: del_cookie(name, *, path='/', domain=None) Deletes cookie. .. versionadded:: 4.0 :param str name: cookie name :param str domain: optional cookie domain :param str path: optional cookie path, ``'/'`` by default Successful Exceptions --------------------- HTTP exceptions for status code in range 200-299. They are not *errors* but special classes reflected in exceptions hierarchy. E.g. ``raise web.HTTPNoContent`` may look strange a little but the construction is absolutely legal. .. exception:: HTTPSuccessful :canonical: aiohttp.web_exceptions.HTTPSuccessful A base class for the category, a subclass of :exc:`HTTPException`. .. exception:: HTTPOk :canonical: aiohttp.web_exceptions.HTTPOk An exception for *200 OK*, a subclass of :exc:`HTTPSuccessful`. .. exception:: HTTPCreated :canonical: aiohttp.web_exceptions.HTTPCreated An exception for *201 Created*, a subclass of :exc:`HTTPSuccessful`. .. exception:: HTTPAccepted :canonical: aiohttp.web_exceptions.HTTPAccepted An exception for *202 Accepted*, a subclass of :exc:`HTTPSuccessful`. .. exception:: HTTPNonAuthoritativeInformation :canonical: aiohttp.web_exceptions.HTTPNonAuthoritativeInformation An exception for *203 Non-Authoritative Information*, a subclass of :exc:`HTTPSuccessful`. .. exception:: HTTPNoContent :canonical: aiohttp.web_exceptions.HTTPNoContent An exception for *204 No Content*, a subclass of :exc:`HTTPSuccessful`. Has no HTTP body. .. exception:: HTTPResetContent :canonical: aiohttp.web_exceptions.HTTPResetContent An exception for *205 Reset Content*, a subclass of :exc:`HTTPSuccessful`. Has no HTTP body. .. 
exception:: HTTPPartialContent :canonical: aiohttp.web_exceptions.HTTPPartialContent An exception for *206 Partial Content*, a subclass of :exc:`HTTPSuccessful`. Redirections ------------ HTTP exceptions for status code in range 300-399, e.g. ``raise web.HTTPMovedPermanently(location='/new/path')``. .. exception:: HTTPRedirection :canonical: aiohttp.web_exceptions.HTTPRedirection A base class for the category, a subclass of :exc:`HTTPException`. .. exception:: HTTPMove(location, *, headers=None, reason=None, text=None, \ content_type=None) :canonical: aiohttp.web_exceptions.HTTPMove A base class for redirections with implied *Location* header, all redirections except :exc:`HTTPNotModified`. :param location: a :class:`yarl.URL` or :class:`str` used for *Location* HTTP header. For other arguments see :exc:`HTTPException` constructor. .. attribute:: location A *Location* HTTP header value, :class:`yarl.URL`. .. exception:: HTTPMultipleChoices :canonical: aiohttp.web_exceptions.HTTPMultipleChoices An exception for *300 Multiple Choices*, a subclass of :exc:`HTTPMove`. .. exception:: HTTPMovedPermanently :canonical: aiohttp.web_exceptions.HTTPMovedPermanently An exception for *301 Moved Permanently*, a subclass of :exc:`HTTPMove`. .. exception:: HTTPFound :canonical: aiohttp.web_exceptions.HTTPFound An exception for *302 Found*, a subclass of :exc:`HTTPMove`. .. exception:: HTTPSeeOther :canonical: aiohttp.web_exceptions.HTTPSeeOther An exception for *303 See Other*, a subclass of :exc:`HTTPMove`. .. exception:: HTTPNotModified :canonical: aiohttp.web_exceptions.HTTPNotModified An exception for *304 Not Modified*, a subclass of :exc:`HTTPRedirection`. Has no HTTP body. .. exception:: HTTPUseProxy :canonical: aiohttp.web_exceptions.HTTPUseProxy An exception for *305 Use Proxy*, a subclass of :exc:`HTTPMove`. .. 
exception:: HTTPTemporaryRedirect :canonical: aiohttp.web_exceptions.HTTPTemporaryRedirect An exception for *307 Temporary Redirect*, a subclass of :exc:`HTTPMove`. .. exception:: HTTPPermanentRedirect :canonical: aiohttp.web_exceptions.HTTPPermanentRedirect An exception for *308 Permanent Redirect*, a subclass of :exc:`HTTPMove`. Client Errors ------------- HTTP exceptions for status code in range 400-499, e.g. ``raise web.HTTPNotFound()``. .. exception:: HTTPClientError :canonical: aiohttp.web_exceptions.HTTPClientError A base class for the category, a subclass of :exc:`HTTPException`. .. exception:: HTTPBadRequest :canonical: aiohttp.web_exceptions.HTTPBadRequest An exception for *400 Bad Request*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPUnauthorized :canonical: aiohttp.web_exceptions.HTTPUnauthorized An exception for *401 Unauthorized*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPPaymentRequired :canonical: aiohttp.web_exceptions.HTTPPaymentRequired An exception for *402 Payment Required*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPForbidden :canonical: aiohttp.web_exceptions.HTTPForbidden An exception for *403 Forbidden*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPNotFound :canonical: aiohttp.web_exceptions.HTTPNotFound An exception for *404 Not Found*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPMethodNotAllowed(method, allowed_methods, *, \ headers=None, reason=None, text=None, \ content_type=None) :canonical: aiohttp.web_exceptions.HTTPMethodNotAllowed An exception for *405 Method Not Allowed*, a subclass of :exc:`HTTPClientError`. :param str method: requested but not allowed HTTP method. :param allowed_methods: an iterable of allowed HTTP methods (:class:`str`), *Allow* HTTP header is constructed from the sequence separated by comma. For other arguments see :exc:`HTTPException` constructor. .. attribute:: allowed_methods A set of allowed HTTP methods. .. 
attribute:: method Requested but not allowed HTTP method. .. exception:: HTTPNotAcceptable :canonical: aiohttp.web_exceptions.HTTPNotAcceptable An exception for *406 Not Acceptable*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPProxyAuthenticationRequired :canonical: aiohttp.web_exceptions.HTTPProxyAuthenticationRequired An exception for *407 Proxy Authentication Required*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPRequestTimeout :canonical: aiohttp.web_exceptions.HTTPRequestTimeout An exception for *408 Request Timeout*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPConflict :canonical: aiohttp.web_exceptions.HTTPConflict An exception for *409 Conflict*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPGone :canonical: aiohttp.web_exceptions.HTTPGone An exception for *410 Gone*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPLengthRequired :canonical: aiohttp.web_exceptions.HTTPLengthRequired An exception for *411 Length Required*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPPreconditionFailed :canonical: aiohttp.web_exceptions.HTTPPreconditionFailed An exception for *412 Precondition Failed*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPRequestEntityTooLarge(max_size, actual_size, **kwargs) :canonical: aiohttp.web_exceptions.HTTPRequestEntityTooLarge An exception for *413 Entity Too Large*, a subclass of :exc:`HTTPClientError`. :param int max_size: Maximum allowed request body size :param int actual_size: Actual received size For other acceptable parameters see :exc:`HTTPException` constructor. .. exception:: HTTPRequestURITooLong :canonical: aiohttp.web_exceptions.HTTPRequestURITooLong An exception for *414 URI is too long*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPUnsupportedMediaType :canonical: aiohttp.web_exceptions.HTTPUnsupportedMediaType An exception for *415 Entity body in unsupported format*, a subclass of :exc:`HTTPClientError`. .. 
exception:: HTTPRequestRangeNotSatisfiable :canonical: aiohttp.web_exceptions.HTTPRequestRangeNotSatisfiable An exception for *416 Cannot satisfy request range*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPExpectationFailed :canonical: aiohttp.web_exceptions.HTTPExpectationFailed An exception for *417 Expect condition could not be satisfied*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPMisdirectedRequest :canonical: aiohttp.web_exceptions.HTTPMisdirectedRequest An exception for *421 Misdirected Request*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPUnprocessableEntity :canonical: aiohttp.web_exceptions.HTTPUnprocessableEntity An exception for *422 Unprocessable Entity*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPFailedDependency :canonical: aiohttp.web_exceptions.HTTPFailedDependency An exception for *424 Failed Dependency*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPUpgradeRequired :canonical: aiohttp.web_exceptions.HTTPUpgradeRequired An exception for *426 Upgrade Required*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPPreconditionRequired :canonical: aiohttp.web_exceptions.HTTPPreconditionRequired An exception for *428 Precondition Required*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPTooManyRequests :canonical: aiohttp.web_exceptions.HTTPTooManyRequests An exception for *429 Too Many Requests*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPRequestHeaderFieldsTooLarge :canonical: aiohttp.web_exceptions.HTTPRequestHeaderFieldsTooLarge An exception for *431 Requests Header Fields Too Large*, a subclass of :exc:`HTTPClientError`. .. exception:: HTTPUnavailableForLegalReasons(link, *, \ headers=None, \ reason=None, \ text=None, \ content_type=None) :canonical: aiohttp.web_exceptions.HTTPUnavailableForLegalReasons An exception for *451 Unavailable For Legal Reasons*, a subclass of :exc:`HTTPClientError`. 
:param link: A link to yourself (as the entity implementing the blockage), :class:`str`, :class:`~yarl.URL` or ``None``. For other parameters see :exc:`HTTPException` constructor. A reason for the block should be included in ``text``. .. attribute:: link A :class:`~yarl.URL` link to the entity implementing the blockage or ``None``, read-only property. Server Errors ------------- HTTP exceptions for status code in range 500-599, e.g. ``raise web.HTTPBadGateway()``. .. exception:: HTTPServerError :canonical: aiohttp.web_exceptions.HTTPServerError A base class for the category, a subclass of :exc:`HTTPException`. .. exception:: HTTPInternalServerError :canonical: aiohttp.web_exceptions.HTTPInternalServerError An exception for *500 Server got itself in trouble*, a subclass of :exc:`HTTPServerError`. .. exception:: HTTPNotImplemented :canonical: aiohttp.web_exceptions.HTTPNotImplemented An exception for *501 Server does not support this operation*, a subclass of :exc:`HTTPServerError`. .. exception:: HTTPBadGateway :canonical: aiohttp.web_exceptions.HTTPBadGateway An exception for *502 Invalid responses from another server/proxy*, a subclass of :exc:`HTTPServerError`. .. exception:: HTTPServiceUnavailable :canonical: aiohttp.web_exceptions.HTTPServiceUnavailable An exception for *503 The server cannot process the request due to a high load*, a subclass of :exc:`HTTPServerError`. .. exception:: HTTPGatewayTimeout :canonical: aiohttp.web_exceptions.HTTPGatewayTimeout An exception for *504 The gateway server did not receive a timely response*, a subclass of :exc:`HTTPServerError`. .. exception:: HTTPVersionNotSupported :canonical: aiohttp.web_exceptions.HTTPVersionNotSupported An exception for *505 Cannot fulfill request*, a subclass of :exc:`HTTPServerError`. .. exception:: HTTPVariantAlsoNegotiates :canonical: aiohttp.web_exceptions.HTTPVariantAlsoNegotiates An exception for *506 Variant Also Negotiates*, a subclass of :exc:`HTTPServerError`. .. 
exception:: HTTPInsufficientStorage :canonical: aiohttp.web_exceptions.HTTPInsufficientStorage An exception for *507 Insufficient Storage*, a subclass of :exc:`HTTPServerError`. .. exception:: HTTPNotExtended :canonical: aiohttp.web_exceptions.HTTPNotExtended An exception for *510 Not Extended*, a subclass of :exc:`HTTPServerError`. .. exception:: HTTPNetworkAuthenticationRequired :canonical: aiohttp.web_exceptions.HTTPNetworkAuthenticationRequired An exception for *511 Network Authentication Required*, a subclass of :exc:`HTTPServerError`. ================================================ FILE: docs/web_lowlevel.rst ================================================ .. currentmodule:: aiohttp.web .. _aiohttp-web-lowlevel: Low Level Server ================ This topic describes :mod:`aiohttp.web` based *low level* API. Abstract -------- Sometimes users don't need high-level concepts introduced in :ref:`aiohttp-web`: applications, routers, middlewares and signals. All that may be needed is supporting an asynchronous callable which accepts a request and returns a response object. This is done by introducing :class:`aiohttp.web.Server` class which serves a *protocol factory* role for :meth:`asyncio.loop.create_server` and bridges data stream to *web handler* and sends result back. Low level *web handler* should accept the single :class:`BaseRequest` parameter and perform one of the following actions: 1. Return a :class:`Response` with the whole HTTP body stored in memory. 2. Create a :class:`StreamResponse`, send headers by :meth:`StreamResponse.prepare` call, send data chunks by :meth:`StreamResponse.write` and return finished response. 3. Raise :class:`HTTPException` derived exception (see :ref:`aiohttp-web-exceptions` section). All other exceptions not derived from :class:`HTTPException` lead to a *500 Internal Server Error* response. 4. Initiate and process a Web-Socket connection by using :class:`WebSocketResponse` (see :ref:`aiohttp-web-websockets`).
Run a Basic Low-Level Server ---------------------------- The following code demonstrates a very trivial usage example:: import asyncio from aiohttp import web async def handler(request): return web.Response(text="OK") async def main(): server = web.Server(handler) runner = web.ServerRunner(server) await runner.setup() site = web.TCPSite(runner, 'localhost', 8080) await site.start() print("======= Serving on http://127.0.0.1:8080/ ======") # pause here for very long time by serving HTTP requests and # waiting for keyboard interruption await asyncio.sleep(100*3600) asyncio.run(main()) In the snippet we have ``handler`` which returns a regular :class:`Response` with ``"OK"`` in BODY. This *handler* is processed by ``server`` (:class:`Server` which acts as *protocol factory*). Network communication is created by :ref:`runners API ` to serve ``http://127.0.0.1:8080/``. The handler should process every request for every *path* and request kind, e.g. ``GET``, ``POST``, Web-Socket. The example is very basic: it always returns a ``200 OK`` response; real-life code is usually much more complex. ================================================ FILE: docs/web_quickstart.rst ================================================ .. currentmodule:: aiohttp.web .. _aiohttp-web-quickstart: Web Server Quickstart ===================== Run a Simple Web Server ----------------------- In order to implement a web server, first create a :ref:`request handler `. A request handler must be a :ref:`coroutine ` that accepts a :class:`Request` instance as its only parameter and returns a :class:`Response` instance:: from aiohttp import web async def hello(request): return web.Response(text="Hello, world") Next, create an :class:`Application` instance and register the request handler on a particular *HTTP method* and *path*:: app = web.Application() app.add_routes([web.get('/', hello)]) After that, run the application by :func:`run_app` call:: web.run_app(app) That's it.
Now, head over to ``http://localhost:8080/`` to see the results. Alternatively, if you prefer *route decorators*, create a *route table* and register a :term:`web-handler`:: routes = web.RouteTableDef() @routes.get('/') async def hello(request): return web.Response(text="Hello, world") app = web.Application() app.add_routes(routes) web.run_app(app) Both ways essentially do the same work; the difference is only a matter of taste: do you prefer *Django style* with famous ``urls.py`` or *Flask* with shiny route decorators. *aiohttp* server documentation uses both ways in code snippets to emphasize their equality; switching from one style to another is very trivial. .. note:: You can get a powerful aiohttp template by running one command. To do this, simply use our `boilerplate for quick start with aiohttp `_. .. seealso:: :ref:`aiohttp-web-graceful-shutdown` section explains what :func:`run_app` does and how to implement complex server initialization/finalization from scratch. :ref:`aiohttp-web-app-runners` for handling more complex cases like *asynchronous* web application serving and multiple hosts support. .. _aiohttp-web-cli: Command Line Interface (CLI) ---------------------------- :mod:`aiohttp.web` implements a basic CLI for quickly serving an :class:`Application` in *development* over TCP/IP: .. code-block:: shell $ python -m aiohttp.web -H localhost -P 8080 package.module:init_func ``package.module:init_func`` should be an importable :term:`callable` that accepts a list of any non-parsed command-line arguments and returns an :class:`Application` instance after setting it up:: def init_func(argv): app = web.Application() app.router.add_get("/", index_handler) return app .. note:: For local development we typically recommend using `aiohttp-devtools `_. .. _aiohttp-web-handler: Handler ------- A request handler must be a :ref:`coroutine` that accepts a :class:`Request` instance as its only argument and returns a :class:`StreamResponse` derived (e.g.
:class:`Response`) instance:: async def handler(request): return web.Response() Handlers are setup to handle requests by registering them with the :meth:`Application.add_routes` on a particular route (*HTTP method* and *path* pair) using helpers like :func:`get` and :func:`post`:: app.add_routes([web.get('/', handler), web.post('/post', post_handler), web.put('/put', put_handler)]) Or use *route decorators*:: routes = web.RouteTableDef() @routes.get('/') async def get_handler(request): ... @routes.post('/post') async def post_handler(request): ... @routes.put('/put') async def put_handler(request): ... app.add_routes(routes) Wildcard *HTTP method* is also supported by :func:`route` or :meth:`RouteTableDef.route`, allowing a handler to serve incoming requests on a *path* having **any** *HTTP method*:: app.add_routes([web.route('*', '/path', all_handler)]) The *HTTP method* can be queried later in the request handler using the :attr:`aiohttp.web.BaseRequest.method` property. By default endpoints added with ``GET`` method will accept ``HEAD`` requests and return the same response headers as they would for a ``GET`` request. You can also deny ``HEAD`` requests on a route:: web.get('/', handler, allow_head=False) Here ``handler`` won't be called on ``HEAD`` request and the server will respond with ``405: Method Not Allowed``. .. seealso:: :ref:`aiohttp-web-peer-disconnection` section explains how handlers behave when a client connection drops and ways to optimize handling of this. .. _aiohttp-web-resource-and-route: Resources and Routes -------------------- Internally routes are served by :attr:`Application.router` (:class:`UrlDispatcher` instance). The *router* is a list of *resources*. Resource is an entry in *route table* which corresponds to requested URL. Resource in turn has at least one *route*. Route corresponds to handling *HTTP method* by calling *web handler*. Thus when you add a *route* the *resource* object is created under the hood. 
The library implementation **merges** all subsequent route additions for the same path adding the only resource for all HTTP methods. Consider two examples:: app.add_routes([web.get('/path1', get_1), web.post('/path1', post_1), web.get('/path2', get_2), web.post('/path2', post_2)] and:: app.add_routes([web.get('/path1', get_1), web.get('/path2', get_2), web.post('/path2', post_2), web.post('/path1', post_1)] First one is *optimized*. You have got the idea. .. _aiohttp-web-variable-handler: Variable Resources ^^^^^^^^^^^^^^^^^^ Resource may have *variable path* also. For instance, a resource with the path ``'/a/{name}/c'`` would match all incoming requests with paths such as ``'/a/b/c'``, ``'/a/1/c'``, and ``'/a/etc/c'``. A variable *part* is specified in the form ``{identifier}``, where the ``identifier`` can be used later in a :ref:`request handler ` to access the matched value for that *part*. This is done by looking up the ``identifier`` in the :attr:`Request.match_info` mapping:: @routes.get('/{name}') async def variable_handler(request): return web.Response( text="Hello, {}".format(request.match_info['name'])) By default, each *part* matches the regular expression ``[^{}/]+``. You can also specify a custom regex in the form ``{identifier:regex}``:: web.get(r'/{name:\d+}', handler) .. _aiohttp-web-named-routes: Reverse URL Constructing using Named Resources ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Routes can also be given a *name*:: @routes.get('/root', name='root') async def handler(request): ... Which can then be used to access and build a *URL* for that resource later (e.g. 
in a :ref:`request handler `):: url = request.app.router['root'].url_for().with_query({"a": "b", "c": "d"}) assert url == URL('/root?a=b&c=d') A more interesting example is building *URLs* for :ref:`variable resources `:: app.router.add_resource(r'/{user}/info', name='user-info') In this case you can also pass in the *parts* of the route:: url = request.app.router['user-info'].url_for(user='john_doe') url_with_qs = url.with_query("a=b") assert url_with_qs == '/john_doe/info?a=b' Organizing Handlers in Classes ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ As discussed above, :ref:`handlers ` can be first-class coroutines:: async def hello(request): return web.Response(text="Hello, world") app.router.add_get('/', hello) But sometimes it's convenient to group logically similar handlers into a Python *class*. Since :mod:`aiohttp.web` does not dictate any implementation details, application developers can organize handlers in classes if they so wish:: class Handler: def __init__(self): pass async def handle_intro(self, request): return web.Response(text="Hello, world") async def handle_greeting(self, request): name = request.match_info.get('name', "Anonymous") txt = "Hello, {}".format(name) return web.Response(text=txt) handler = Handler() app.add_routes([web.get('/intro', handler.handle_intro), web.get('/greet/{name}', handler.handle_greeting)]) .. _aiohttp-web-class-based-views: Class Based Views ^^^^^^^^^^^^^^^^^ :mod:`aiohttp.web` has support for *class based views*. You can derive from :class:`View` and define methods for handling http requests:: class MyView(web.View): async def get(self): return await get_resp(self.request) async def post(self): return await post_resp(self.request) Handlers should be coroutines accepting *self* only and returning response object as regular :term:`web-handler`. Request object can be retrieved by :attr:`View.request` property. 
After implementing the view (``MyView`` from the example above), it should be registered in the application's router:: app.add_routes([web.view('/path/to', MyView)]) or:: @routes.view('/path/to') class MyView(web.View): ... or:: app.router.add_route('*', '/path/to', MyView) Example will process GET and POST requests for */path/to* but raise *405 Method not allowed* exception for unimplemented HTTP methods. Resource Views ^^^^^^^^^^^^^^ *All* registered resources in a router can be viewed using the :meth:`UrlDispatcher.resources` method:: for resource in app.router.resources(): print(resource) A *subset* of the resources that were registered with a *name* can be viewed using the :meth:`UrlDispatcher.named_resources` method:: for name, resource in app.router.named_resources().items(): print(name, resource) .. _aiohttp-web-alternative-routes-definition: Alternative ways for registering routes --------------------------------------- Code examples shown above use *imperative* style for adding new routes: they call ``app.router.add_get(...)`` etc. There are two alternatives: route tables and route decorators. Route tables look like the Django way:: async def handle_get(request): ... async def handle_post(request): ... app.router.add_routes([web.get('/get', handle_get), web.post('/post', handle_post)]) The snippet calls :meth:`~aiohttp.web.UrlDispatcher.add_routes` to register a list of *route definitions* (:class:`aiohttp.web.RouteDef` instances) created by :func:`aiohttp.web.get` or :func:`aiohttp.web.post` functions. .. seealso:: :ref:`aiohttp-web-route-def` reference. Route decorators are closer to the Flask approach:: routes = web.RouteTableDef() @routes.get('/get') async def handle_get(request): ... @routes.post('/post') async def handle_post(request): ... app.router.add_routes(routes) It is also possible to use decorators with class-based views:: routes = web.RouteTableDef() @routes.view("/view") class MyView(web.View): async def get(self): ... async def post(self): ...
app.router.add_routes(routes) The example creates a :class:`aiohttp.web.RouteTableDef` container first. The container is a list-like object with additional decorators :meth:`aiohttp.web.RouteTableDef.get`, :meth:`aiohttp.web.RouteTableDef.post` etc. for registering new routes. After filling the container :meth:`~aiohttp.web.UrlDispatcher.add_routes` is used for adding registered *route definitions* into application's router. .. seealso:: :ref:`aiohttp-web-route-table-def` reference. All three ways (imperative calls, route tables and decorators) are equivalent, so you can use whichever you prefer or even mix them as you like. .. versionadded:: 2.3 JSON Response ------------- It is a common case to return JSON data in response, :mod:`aiohttp.web` provides a shortcut for returning JSON -- :func:`aiohttp.web.json_response`:: async def handler(request): data = {'some': 'data'} return web.json_response(data) The shortcut method returns :class:`aiohttp.web.Response` instance so you can for example set cookies before returning it from handler. User Sessions ------------- Often you need a container for storing user data across requests. The concept is usually called a *session*.
:mod:`aiohttp.web` has no built-in concept of a *session*, however, there is a third-party library, :mod:`aiohttp_session`, that adds *session* support:: import asyncio import time import base64 from cryptography import fernet from aiohttp import web from aiohttp_session import setup, get_session, session_middleware from aiohttp_session.cookie_storage import EncryptedCookieStorage async def handler(request): session = await get_session(request) last_visit = session.get("last_visit") session["last_visit"] = time.time() text = "Last visited: {}".format(last_visit) return web.Response(text=text) async def make_app(): app = web.Application() # secret_key must be 32 url-safe base64-encoded bytes fernet_key = fernet.Fernet.generate_key() secret_key = base64.urlsafe_b64decode(fernet_key) setup(app, EncryptedCookieStorage(secret_key)) app.add_routes([web.get('/', handler)]) return app web.run_app(make_app()) .. _aiohttp-web-forms: HTTP Forms ---------- HTTP Forms are supported out of the box. If form's method is ``"GET"`` (``
    ``) use :attr:`aiohttp.web.BaseRequest.query` for getting form data. To access form data with ``"POST"`` method use :meth:`aiohttp.web.BaseRequest.post` or :meth:`aiohttp.web.BaseRequest.multipart`. :meth:`aiohttp.web.BaseRequest.post` accepts both ``'application/x-www-form-urlencoded'`` and ``'multipart/form-data'`` form's data encoding (e.g. ````). It stores files data in temporary directory. If `client_max_size` is specified `post` raises `ValueError` exception. For efficiency use :meth:`aiohttp.web.BaseRequest.multipart`, It is especially effective for uploading large files (:ref:`aiohttp-web-file-upload`). Values submitted by the following form: .. code-block:: html
    could be accessed as:: async def do_login(request): data = await request.post() login = data['login'] password = data['password'] .. _aiohttp-web-file-upload: File Uploads ------------ :mod:`aiohttp.web` has built-in support for handling files uploaded from the browser. First, make sure that the HTML ``
    `` element has its *enctype* attribute set to ``enctype="multipart/form-data"``. As an example, here is a form that accepts an MP3 file: .. code-block:: html
    Then, in the :ref:`request handler ` you can access the file input field as a :class:`FileField` instance. :class:`FileField` is simply a container for the file as well as some of its metadata:: async def store_mp3_handler(request): # WARNING: don't do that if you plan to receive large files! data = await request.post() mp3 = data['mp3'] # .filename contains the name of the file in string format. filename = mp3.filename # .file contains the actual file data that needs to be stored somewhere. mp3_file = data['mp3'].file content = mp3_file.read() return web.Response(body=content, headers=MultiDict( {'CONTENT-DISPOSITION': mp3_file})) You might have noticed a big warning in the example above. The general issue is that :meth:`aiohttp.web.BaseRequest.post` reads the whole payload in memory, resulting in possible :abbr:`OOM (Out Of Memory)` errors. To avoid this, for multipart uploads, you should use :meth:`aiohttp.web.BaseRequest.multipart` which returns a :ref:`multipart reader `:: async def store_mp3_handler(request): reader = await request.multipart() # /!\ Don't forget to validate your inputs /!\ # reader.next() will `yield` the fields of your form field = await reader.next() assert field.name == 'name' name = await field.read(decode=True) field = await reader.next() assert field.name == 'mp3' filename = field.filename # You cannot rely on Content-Length if transfer is chunked. size = 0 with open(os.path.join('/spool/yarrr-media/mp3/', filename), 'wb') as f: while True: chunk = await field.read_chunk() # 8192 bytes by default. if not chunk: break size += len(chunk) f.write(chunk) return web.Response(text='{} sized of {} successfully stored' ''.format(filename, size)) .. _aiohttp-web-websockets: WebSockets ---------- :mod:`aiohttp.web` supports *WebSockets* out-of-the-box. 
To setup a *WebSocket*, create a :class:`WebSocketResponse` in a :ref:`request handler ` and then use it to communicate with the peer:: async def websocket_handler(request): ws = web.WebSocketResponse() await ws.prepare(request) async for msg in ws: # ws.__next__() automatically terminates the loop # after ws.close() or ws.exception() is called if msg.type == aiohttp.WSMsgType.TEXT: if msg.data == 'close': await ws.close() else: await ws.send_str(msg.data + '/answer') elif msg.type == aiohttp.WSMsgType.ERROR: print('ws connection closed with exception %s' % ws.exception()) print('websocket connection closed') return ws The handler should be registered as HTTP GET processor:: app.add_routes([web.get('/ws', websocket_handler)]) .. _aiohttp-web-redirects: Redirects --------- To redirect user to another endpoint - raise :class:`HTTPFound` with an absolute URL, relative URL or view name (the argument from router):: raise web.HTTPFound('/redirect') The following example shows redirect to view named 'login' in routes:: async def handler(request): location = request.app.router['login'].url_for() raise web.HTTPFound(location=location) router.add_get('/handler', handler) router.add_get('/login', login_handler, name='login') Example with login validation:: @aiohttp_jinja2.template('login.html') async def login(request): if request.method == 'POST': form = await request.post() error = validate_login(form) if error: return {'error': error} else: # login form is valid location = request.app.router['index'].url_for() raise web.HTTPFound(location=location) return {} app.router.add_get('/', index, name='index') app.router.add_get('/login', login, name='login') app.router.add_post('/login', login, name='login') ================================================ FILE: docs/web_reference.rst ================================================ .. currentmodule:: aiohttp.web .. _aiohttp-web-reference: Server Reference ================ .. 
_aiohttp-web-request: Request and Base Request ------------------------ The Request object contains all the information about an incoming HTTP request. :class:`BaseRequest` is used for :ref:`Low-Level Servers` (which have no applications, routers, signals and middlewares). :class:`Request` has an :attr:`Request.app` and :attr:`Request.match_info` attributes. A :class:`BaseRequest` / :class:`Request` are :obj:`dict` like objects, allowing them to be used for :ref:`sharing data` among :ref:`aiohttp-web-middlewares` and :ref:`aiohttp-web-signals` handlers. .. class:: BaseRequest :canonical: aiohttp.web_request.BaseRequest .. attribute:: version *HTTP version* of request, Read-only property. Returns :class:`aiohttp.protocol.HttpVersion` instance. .. attribute:: method *HTTP method*, read-only property. The value is upper-cased :class:`str` like ``"GET"``, ``"POST"``, ``"PUT"`` etc. .. attribute:: url A :class:`~yarl.URL` instance with absolute URL to resource (*scheme*, *host* and *port* are included). .. note:: In case of malformed request (e.g. without ``"HOST"`` HTTP header) the absolute url may be unavailable. .. attribute:: rel_url A :class:`~yarl.URL` instance with relative URL to resource (contains *path*, *query* and *fragment* parts only, *scheme*, *host* and *port* are excluded). The property is equal to ``.url.relative()`` but is always present. .. seealso:: A note from :attr:`url`. .. attribute:: scheme A string representing the scheme of the request. The scheme is ``'https'`` if transport for request handling is *SSL*, ``'http'`` otherwise. The value could be overridden by :meth:`~BaseRequest.clone`. Read-only :class:`str` property. .. versionchanged:: 2.3 *Forwarded* and *X-Forwarded-Proto* are not used anymore. Call ``.clone(scheme=new_scheme)`` for setting up the value explicitly. .. seealso:: :ref:`aiohttp-web-forwarded-support` .. attribute:: secure Shorthand for ``request.url.scheme == 'https'`` Read-only :class:`bool` property. .. 
seealso:: :attr:`scheme` .. attribute:: forwarded A tuple containing all parsed Forwarded header(s). Makes an effort to parse Forwarded headers as specified by :rfc:`7239`: - It adds one (immutable) dictionary per Forwarded ``field-value``, i.e. per proxy. The element corresponds to the data in the Forwarded ``field-value`` added by the first proxy encountered by the client. Each subsequent item corresponds to those added by later proxies. - It checks that every value has valid syntax in general as specified in :rfc:`7239#section-4`: either a ``token`` or a ``quoted-string``. - It un-escapes ``quoted-pairs``. - It does NOT validate 'by' and 'for' contents as specified in :rfc:`7239#section-6`. - It does NOT validate ``host`` contents (Host ABNF). - It does NOT validate ``proto`` contents for valid URI scheme names. Returns a tuple containing one or more ``MappingProxy`` objects .. seealso:: :attr:`scheme` .. seealso:: :attr:`host` .. attribute:: host Host name of the request, resolved in this order: - Overridden value by :meth:`~BaseRequest.clone` call. - *Host* HTTP header - :func:`socket.getfqdn` Read-only :class:`str` property. .. versionchanged:: 2.3 *Forwarded* and *X-Forwarded-Host* are not used anymore. Call ``.clone(host=new_host)`` for setting up the value explicitly. .. seealso:: :ref:`aiohttp-web-forwarded-support` .. attribute:: remote Originating IP address of a client initiated HTTP request. The IP is resolved through the following headers, in this order: - Overridden value by :meth:`~BaseRequest.clone` call. - Peer name of opened socket. Read-only :class:`str` property. Call ``.clone(remote=new_remote)`` for setting up the value explicitly. .. versionadded:: 2.3 .. seealso:: :ref:`aiohttp-web-forwarded-support` .. attribute:: client_max_size The maximum size of the request body. The value could be overridden by :meth:`~BaseRequest.clone`. Read-only :class:`int` property. .. attribute:: path_qs The URL including PATH_INFO and the query string. 
e.g., ``/app/blog?id=10`` Read-only :class:`str` property. .. attribute:: path The URL including *PATH INFO* without the host or scheme. e.g., ``/app/blog``. The path is URL-decoded. For raw path info see :attr:`raw_path`. Read-only :class:`str` property. .. attribute:: raw_path The URL including raw *PATH INFO* without the host or scheme. Warning, the path may be URL-encoded and may contain invalid URL characters, e.g. ``/my%2Fpath%7Cwith%21some%25strange%24characters``. For URL-decoded version please take a look at :attr:`path`. Read-only :class:`str` property. .. attribute:: query A multidict with all the variables in the query string. Read-only :class:`~multidict.MultiDictProxy` lazy property. .. attribute:: query_string The query string in the URL, e.g., ``id=10`` Read-only :class:`str` property. .. attribute:: headers A case-insensitive multidict proxy with all headers. Read-only :class:`~multidict.CIMultiDictProxy` property. .. attribute:: raw_headers HTTP headers of the request as unconverted bytes, a sequence of ``(key, value)`` pairs. .. attribute:: keep_alive ``True`` if keep-alive connection enabled by HTTP client and protocol version supports it, otherwise ``False``. Read-only :class:`bool` property. .. attribute:: transport A :ref:`transport` used to process request. Read-only property. The property can be used, for example, for getting IP address of client's peer:: peername = request.transport.get_extra_info('peername') if peername is not None: host, port = peername .. attribute:: cookies A read-only dictionary-like object containing the request's cookies. Read-only :class:`~types.MappingProxyType` property. .. attribute:: content A :class:`~aiohttp.StreamReader` instance, input stream for reading request's *BODY*. Read-only property. .. attribute:: body_exists Return ``True`` if request has *HTTP BODY*, ``False`` otherwise. Read-only :class:`bool` property. .. versionadded:: 2.3 ..
attribute:: can_read_body Return ``True`` if request's *HTTP BODY* can be read, ``False`` otherwise. Read-only :class:`bool` property. .. versionadded:: 2.3 .. attribute:: content_type Read-only property with *content* part of *Content-Type* header. Returns :class:`str` like ``'text/html'`` .. note:: Returns value is ``'application/octet-stream'`` if no Content-Type header present in HTTP headers according to :rfc:`2616` .. attribute:: charset Read-only property that specifies the *encoding* for the request's BODY. The value is parsed from the *Content-Type* HTTP header. Returns :class:`str` like ``'utf-8'`` or ``None`` if *Content-Type* has no charset information. .. attribute:: content_length Read-only property that returns length of the request's BODY. The value is parsed from the *Content-Length* HTTP header. Returns :class:`int` or ``None`` if *Content-Length* is absent. .. attribute:: http_range Read-only property that returns information about *Range* HTTP header. Returns a :class:`slice` where ``.start`` is *left inclusive bound*, ``.stop`` is *right exclusive bound* and ``.step`` is ``1``. The property might be used in two manners: 1. Attribute-access style (example assumes that both left and right borders are set, the real logic for case of open bounds is more complex):: rng = request.http_range with open(filename, 'rb') as f: f.seek(rng.start) return f.read(rng.stop-rng.start) 2. Slice-style:: return buffer[request.http_range] .. attribute:: if_modified_since Read-only property that returns the date specified in the *If-Modified-Since* header. Returns :class:`datetime.datetime` or ``None`` if *If-Modified-Since* header is absent or is not a valid HTTP date. .. attribute:: if_unmodified_since Read-only property that returns the date specified in the *If-Unmodified-Since* header. Returns :class:`datetime.datetime` or ``None`` if *If-Unmodified-Since* header is absent or is not a valid HTTP date. .. versionadded:: 3.1 .. 
attribute:: if_match Read-only property that returns :class:`~aiohttp.ETag` objects specified in the *If-Match* header. Returns :class:`tuple` of :class:`~aiohttp.ETag` or ``None`` if *If-Match* header is absent. .. versionadded:: 3.8 .. attribute:: if_none_match Read-only property that returns :class:`~aiohttp.ETag` objects specified in the *If-None-Match* header. Returns :class:`tuple` of :class:`~aiohttp.ETag` or ``None`` if *If-None-Match* header is absent. .. versionadded:: 3.8 .. attribute:: if_range Read-only property that returns the date specified in the *If-Range* header. Returns :class:`datetime.datetime` or ``None`` if *If-Range* header is absent or is not a valid HTTP date. .. versionadded:: 3.1 .. method:: clone(*, method=..., rel_url=..., headers=...) Clone itself, replacing some attributes. Creates and returns a new instance of Request object. If no parameters are given, an exact copy is returned. If a parameter is not passed, it will reuse the one from the current request object. :param str method: http method :param rel_url: url to use, :class:`str` or :class:`~yarl.URL` :param headers: :class:`~multidict.CIMultiDict` or compatible headers container. :return: a cloned :class:`Request` instance. .. method:: get_extra_info(name, default=None) Reads extra information from the protocol's transport. If no value associated with ``name`` is found, ``default`` is returned. See :meth:`asyncio.BaseTransport.get_extra_info` :param str name: The key to look up in the transport extra information. :param default: Default value to be used when no value for ``name`` is found (default is ``None``). .. versionadded:: 3.7 .. method:: read() :async: Read request body, returns :class:`bytes` object with body content. .. note:: The method **does** store read data internally, subsequent :meth:`~aiohttp.web.BaseRequest.read` call will return the same value. ..
method:: text() :async: Read request body, decode it using :attr:`charset` encoding or ``UTF-8`` if no encoding was specified in *MIME-type*. Returns :class:`str` with body content. .. note:: The method **does** store read data internally, subsequent :meth:`~aiohttp.web.BaseRequest.text` call will return the same value. .. method:: json(*, loads=json.loads, \ content_type='application/json') :async: Read request body decoded as *json*. If request's content-type does not match `content_type` parameter, :exc:`aiohttp.web.HTTPBadRequest` is raised. To disable content type check pass ``None`` value. :param collections.abc.Callable loads: any :term:`callable` that accepts :class:`str` and returns :class:`dict` with parsed JSON (:func:`json.loads` by default). :param str content_type: expected value of Content-Type header or ``None`` ('application/json' by default) .. note:: The method **does** store read data internally, subsequent :meth:`~aiohttp.web.BaseRequest.json` call will return the same value. .. method:: multipart() :async: Returns :class:`aiohttp.MultipartReader` which processes incoming *multipart* request. The method is just a boilerplate :ref:`coroutine <coroutine>` implemented as:: async def multipart(self, *, reader=aiohttp.multipart.MultipartReader): return reader(self.headers, self._payload) This method is a coroutine for consistency with the other reader methods. .. warning:: The method **does not** store read data internally. That means once you exhaust the multipart reader, you cannot get the request payload one more time. .. seealso:: :ref:`aiohttp-multipart` .. versionchanged:: 3.4 Dropped *reader* parameter. .. method:: post() :async: A :ref:`coroutine <coroutine>` that reads POST parameters from request body. Returns :class:`~multidict.MultiDictProxy` instance filled with parsed data.
If :attr:`method` is not *POST*, *PUT*, *PATCH*, *TRACE* or *DELETE* or :attr:`content_type` is not empty or *application/x-www-form-urlencoded* or *multipart/form-data* returns empty multidict. .. note:: The method **does** store read data internally, subsequent :meth:`~aiohttp.web.BaseRequest.post` call will return the same value. .. method:: release() :async: Release request. Eat unread part of HTTP BODY if present. .. note:: User code does not need to call :meth:`~aiohttp.web.BaseRequest.release`; all required work will be processed by :mod:`aiohttp.web` internal machinery. .. class:: Request :canonical: aiohttp.web_request.Request A request used for receiving request's information by *web handler*. Every :ref:`handler <aiohttp-web-handler>` accepts a request instance as the first positional parameter. The class is derived from :class:`BaseRequest`, shares all parent's attributes and methods but has a couple of additional properties: .. attribute:: match_info Read-only property with :class:`~aiohttp.abc.AbstractMatchInfo` instance for result of route resolving. .. note:: Exact type of property depends on used router. If ``app.router`` is :class:`UrlDispatcher` the property contains :class:`UrlMappingMatchInfo` instance. .. attribute:: app An :class:`Application` instance used to call :ref:`request handler <aiohttp-web-handler>`, Read-only property. .. attribute:: config_dict A :class:`aiohttp.ChainMapProxy` instance for mapping all properties from the current application returned by :attr:`app` property and all its parents. .. seealso:: :ref:`aiohttp-web-data-sharing-app-config` .. versionadded:: 3.2 .. note:: You should never create the :class:`Request` instance manually -- :mod:`aiohttp.web` does it for you. But :meth:`~BaseRequest.clone` may be used for cloning *modified* request copy with changed *path*, *method* etc. .. class:: RequestKey(name, t) :canonical: aiohttp.helpers.RequestKey Keys for use in :class:`Request`. See :class:`AppKey` for more details. ..
_aiohttp-web-response: Response classes ---------------- For now, :mod:`aiohttp.web` has three classes for the *HTTP response*: :class:`StreamResponse`, :class:`Response` and :class:`FileResponse`. Usually you need to use the second one. :class:`StreamResponse` is intended for streaming data, while :class:`Response` contains *HTTP BODY* as an attribute and sends own content as single piece with the correct *Content-Length HTTP header*. For sake of design decisions :class:`Response` is derived from :class:`StreamResponse` parent class. The response supports *keep-alive* handling out-of-the-box if *request* supports it. You can disable *keep-alive* by :meth:`~StreamResponse.force_close` though. The common case for sending an answer from :ref:`web-handler` is returning a :class:`Response` instance:: async def handler(request): return Response(text="All right!") Response classes are :obj:`dict` like objects, allowing them to be used for :ref:`sharing data` among :ref:`aiohttp-web-middlewares` and :ref:`aiohttp-web-signals` handlers:: resp['key'] = value .. versionadded:: 3.0 Dict-like interface support. .. class:: StreamResponse(*, status=200, reason=None) :canonical: aiohttp.web_response.StreamResponse The base class for the *HTTP response* handling. Contains methods for setting *HTTP response headers*, *cookies*, *response status code*, writing *HTTP response BODY* and so on. The most important thing you should know about *response* --- it is *Finite State Machine*. That means you can do any manipulations with *headers*, *cookies* and *status code* only before :meth:`prepare` coroutine is called. Once you call :meth:`prepare` any change of the *HTTP header* part will raise :exc:`RuntimeError` exception. Any :meth:`write` call after :meth:`write_eof` is also forbidden. :param int status: HTTP status code, ``200`` by default. :param str reason: HTTP reason. If param is ``None`` reason will be calculated basing on *status* parameter. 
Otherwise pass :class:`str` with arbitrary *status* explanation. .. attribute:: prepared Read-only :class:`bool` property, ``True`` if :meth:`prepare` has been called, ``False`` otherwise. .. attribute:: task A task that serves HTTP request handling. May be useful for graceful shutdown of long-running requests (streaming, long polling or web-socket). .. attribute:: status Read-only property for *HTTP response status code*, :class:`int`. ``200`` (OK) by default. .. attribute:: reason Read-only property for *HTTP response reason*, :class:`str`. .. method:: set_status(status, reason=None) Set :attr:`status` and :attr:`reason`. *reason* value is auto calculated if not specified (``None``). .. attribute:: keep_alive Read-only property, copy of :attr:`aiohttp.web.BaseRequest.keep_alive` by default. Can be switched to ``False`` by :meth:`force_close` call. .. method:: force_close Disable :attr:`keep_alive` for connection. There is no way to enable it back. .. attribute:: compression Read-only :class:`bool` property, ``True`` if compression is enabled. ``False`` by default. .. seealso:: :meth:`enable_compression` .. method:: enable_compression(force=None, strategy=None) Enable compression. When *force* is unset compression encoding is selected based on the request's *Accept-Encoding* header. *Accept-Encoding* is not checked if *force* is set to a :class:`ContentCoding`. *strategy* accepts a :mod:`zlib` compression strategy. See :func:`zlib.compressobj` for possible values, or, if you have used :func:`aiohttp.set_zlib_backend` to change the zlib backend, refer to the documentation of the backend you are using. If ``None``, the default value adopted by your zlib backend will be used where applicable. .. seealso:: :attr:`compression` .. attribute:: chunked Read-only property, indicates if chunked encoding is on. Can be enabled by :meth:`enable_chunked_encoding` call. .. seealso:: :meth:`enable_chunked_encoding` ..
method:: enable_chunked_encoding() Enables :attr:`chunked` encoding for response. There is no way to disable it afterwards. With enabled :attr:`chunked` encoding each :meth:`write` operation is encoded in a separate chunk. .. warning:: chunked encoding can be enabled for ``HTTP/1.1`` only. Setting :attr:`content_length` and enabling chunked encoding are mutually exclusive. .. seealso:: :attr:`chunked` .. attribute:: headers :class:`~multidict.CIMultiDict` instance for *outgoing* *HTTP headers*. .. attribute:: cookies An instance of :class:`http.cookies.SimpleCookie` for *outgoing* cookies. .. warning:: A *Set-Cookie* header that is set directly may be overwritten by explicit calls to the cookie manipulation methods. We encourage using :attr:`cookies` and :meth:`set_cookie`, :meth:`del_cookie` for cookie manipulations. .. method:: set_cookie(name, value, *, path='/', expires=None, \ domain=None, max_age=None, \ secure=None, httponly=None, samesite=None, \ partitioned=None) Convenient way for setting :attr:`cookies`, allows to specify some additional properties like *max_age* in a single call. :param str name: cookie name :param str value: cookie value (will be converted to :class:`str` if value has another type). :param expires: expiration date (optional) :param str domain: cookie domain (optional) :param int max_age: defines the lifetime of the cookie, in seconds. The delta-seconds value is a decimal non-negative integer. After delta-seconds seconds elapse, the client should discard the cookie. A value of zero means the cookie should be discarded immediately. (optional) :param str path: specifies the subset of URLs to which this cookie applies. (optional, ``'/'`` by default) :param bool secure: attribute (with no value) directs the user agent to use only (unspecified) secure means to contact the origin server whenever it sends back this cookie. The user agent (possibly under the user's control) may determine what level of security it considers appropriate for "secure" cookies.
The *secure* should be considered security advice from the server to the user agent, indicating that it is in the session's interest to protect the cookie contents. (optional) :param bool httponly: ``True`` if the cookie HTTP only (optional) :param str samesite: Asserts that a cookie must not be sent with cross-origin requests, providing some protection against cross-site request forgery attacks. Generally the value should be one of: ``None``, ``Lax`` or ``Strict``. (optional) .. versionadded:: 3.7 :param bool partitioned: ``True`` to set a partitioned cookie. Available in Python 3.14+. (optional) .. versionadded:: 3.12 .. method:: del_cookie(name, *, path='/', domain=None) Deletes cookie. :param str name: cookie name :param str domain: optional cookie domain :param str path: optional cookie path, ``'/'`` by default .. attribute:: content_length *Content-Length* for outgoing response. .. attribute:: content_type *Content* part of *Content-Type* for outgoing response. .. attribute:: charset *Charset* aka *encoding* part of *Content-Type* for outgoing response. The value converted to lower-case on attribute assigning. .. attribute:: last_modified *Last-Modified* header for outgoing response. This property accepts raw :class:`str` values, :class:`datetime.datetime` objects, Unix timestamps specified as an :class:`int` or a :class:`float` object, and the value ``None`` to unset the header. .. attribute:: etag *ETag* header for outgoing response. This property accepts raw :class:`str` values, :class:`~aiohttp.ETag` objects and the value ``None`` to unset the header. In case of :class:`str` input, etag is considered as strong by default. **Do not** use double quotes ``"`` in the etag value, they will be added automatically. .. versionadded:: 3.8 .. method:: prepare(request) :async: :param aiohttp.web.Request request: HTTP request object, that the response answers. Send *HTTP header*. You should not change any header data after calling this method. 
The coroutine calls :attr:`~aiohttp.web.Application.on_response_prepare` signal handlers after default headers have been computed and directly before headers are sent. .. method:: write(data) :async: Send byte-ish data as the part of *response BODY*:: await resp.write(data) :meth:`prepare` must be invoked before the call. Raises :exc:`TypeError` if data is not :class:`bytes`, :class:`bytearray` or :class:`memoryview` instance. Raises :exc:`RuntimeError` if :meth:`prepare` has not been called. Raises :exc:`RuntimeError` if :meth:`write_eof` has been called. .. method:: write_eof() :async: A :ref:`coroutine` *may* be called as a mark of the *HTTP response* processing finish. *Internal machinery* will call this method at the end of the request processing if needed. After :meth:`write_eof` call any manipulations with the *response* object are forbidden. .. class:: Response(*, body=None, status=200, reason=None, text=None, \ headers=None, content_type=None, charset=None, \ zlib_executor_size=sentinel, zlib_executor=None) :canonical: aiohttp.web_response.Response The most usable response class, inherited from :class:`StreamResponse`. Accepts *body* argument for setting the *HTTP response BODY*. The actual :attr:`body` sending happens in overridden :meth:`~StreamResponse.write_eof`. :param bytes body: response's BODY :param int status: HTTP status code, 200 OK by default. :param collections.abc.Mapping headers: HTTP headers that should be added to response's ones. :param str text: response's BODY :param str content_type: response's content type. ``'text/plain'`` if *text* is passed also, ``'application/octet-stream'`` otherwise. :param str charset: response's charset. ``'utf-8'`` if *text* is passed also, ``None`` otherwise. :param int zlib_executor_size: length in bytes which will trigger zlib compression of body to happen in an executor .. versionadded:: 3.5 :param int zlib_executor: executor to use for zlib compression .. versionadded:: 3.5 .. 
attribute:: body Read-write attribute for storing response's content aka BODY, :class:`bytes`. Assigning :class:`str` to :attr:`body` will make the :attr:`body` type of :class:`aiohttp.payload.StringPayload`, which tries to encode the given data based on *Content-Type* HTTP header, while defaulting to ``UTF-8``. .. attribute:: text Read-write attribute for storing response's :attr:`~aiohttp.StreamResponse.body`, represented as :class:`str`. .. class:: FileResponse(*, path, chunk_size=256*1024, status=200, reason=None, headers=None) :canonical: aiohttp.web_fileresponse.FileResponse The response class used to send files, inherited from :class:`StreamResponse`. Supports the ``Content-Range`` and ``If-Range`` HTTP Headers in requests. The actual :attr:`body` sending happens in overridden :meth:`~StreamResponse.prepare`. :param path: Path to file. Accepts both :class:`str` and :class:`pathlib.Path`. :param int chunk_size: Chunk size in bytes which will be passed into :meth:`io.RawIOBase.read` in the event that the ``sendfile`` system call is not supported. :param int status: HTTP status code, ``200`` by default. :param str reason: HTTP reason. If param is ``None`` reason will be calculated basing on *status* parameter. Otherwise pass :class:`str` with arbitrary *status* explanation.. :param collections.abc.Mapping headers: HTTP headers that should be added to response's ones. The ``Content-Type`` response header will be overridden if provided. .. class:: WebSocketResponse(*, timeout=10.0, receive_timeout=None, \ autoclose=True, autoping=True, heartbeat=None, \ protocols=(), compress=True, max_msg_size=4194304, \ writer_limit=65536, decode_text=True) :canonical: aiohttp.web_ws.WebSocketResponse Class for handling server-side websockets, inherited from :class:`StreamResponse`. 
After starting (by :meth:`prepare` call) the response you cannot use :meth:`~StreamResponse.write` method but should communicate with the websocket client by :meth:`send_str`, :meth:`receive` and others. To enable back-pressure from slow websocket clients treat methods :meth:`ping`, :meth:`pong`, :meth:`send_str`, :meth:`send_bytes`, :meth:`send_json`, :meth:`send_frame` as coroutines. By default write buffer size is set to 64k. :param bool autoping: Automatically send :const:`~aiohttp.WSMsgType.PONG` on :const:`~aiohttp.WSMsgType.PING` message from client, and handle :const:`~aiohttp.WSMsgType.PONG` responses from client. Note that server does not send :const:`~aiohttp.WSMsgType.PING` requests, you need to do this explicitly using :meth:`ping` method. :param float heartbeat: Send `ping` message every `heartbeat` seconds and wait `pong` response, close connection if `pong` response is not received. The timer is reset on any inbound data reception (coalesced per event loop iteration). :param float timeout: Timeout value for the ``close`` operation. After sending the close websocket message, ``close`` waits for ``timeout`` seconds for a response. Default value is ``10.0`` (10 seconds for ``close`` operation) :param float receive_timeout: Timeout value for `receive` operations. Default value is :data:`None` (no timeout for receive operation) :param bool compress: Enable per-message deflate extension support. :data:`False` for disabled, default value is :data:`True`. :param int max_msg_size: maximum size of read websocket message, 4 MB by default. To disable the size limit use ``0``. .. versionadded:: 3.3 :param bool autoclose: Close connection when the client sends a :const:`~aiohttp.WSMsgType.CLOSE` message, ``True`` by default. If set to ``False``, the connection is not closed and the caller is responsible for calling ``request.transport.close()`` to avoid leaking resources. :param int writer_limit: maximum size of write buffer, 64 KB by default.
Once the buffer is full, the websocket will pause to drain the buffer. .. versionadded:: 3.11 :param bool decode_text: If ``True`` (default), TEXT messages are decoded to strings. If ``False``, TEXT messages are returned as raw bytes, which can improve performance when using JSON parsers like ``orjson`` that accept bytes directly. .. versionadded:: 3.14 The class supports ``async for`` statement for iterating over incoming messages:: ws = web.WebSocketResponse() await ws.prepare(request) async for msg in ws: print(msg.data) .. method:: prepare(request) :async: Starts websocket. After the call you can use websocket methods. :param aiohttp.web.Request request: HTTP request object, that the response answers. :raises HTTPException: if websocket handshake has failed. .. method:: can_prepare(request) Performs checks for *request* data to figure out if websocket can be started on the request. If the :meth:`can_prepare` call succeeds then :meth:`prepare` will succeed too. :param aiohttp.web.Request request: HTTP request object, that the response answers. :return: :class:`WebSocketReady` instance. :attr:`WebSocketReady.ok` is ``True`` on success, :attr:`WebSocketReady.protocol` is websocket subprotocol which is passed by client and accepted by server (one of *protocols* sequence from :class:`WebSocketResponse` ctor). :attr:`WebSocketReady.protocol` may be ``None`` if client and server subprotocols are not overlapping. .. note:: The method never raises exception. .. attribute:: closed Read-only property, ``True`` if connection has been closed or in process of closing. :const:`~aiohttp.WSMsgType.CLOSE` message has been received from peer. .. attribute:: prepared Read-only :class:`bool` property, ``True`` if :meth:`prepare` has been called, ``False`` otherwise. .. attribute:: close_code Read-only property, close code from peer. It is set to ``None`` on opened connection. .. attribute:: ws_protocol Websocket *subprotocol* chosen after :meth:`prepare` call.
May be ``None`` if server and client protocols are not overlapping. .. method:: get_extra_info(name, default=None) Reads optional extra information from the writer's transport. If no value associated with ``name`` is found, ``default`` is returned. See :meth:`asyncio.BaseTransport.get_extra_info` :param str name: The key to look up in the transport extra information. :param default: Default value to be used when no value for ``name`` is found (default is ``None``). .. method:: exception() Returns last occurred exception or None. .. method:: ping(message=b'') :async: Send :const:`~aiohttp.WSMsgType.PING` to peer. :param message: optional payload of *ping* message, :class:`str` (converted to *UTF-8* encoded bytes) or :class:`bytes`. :raise RuntimeError: if the connection is not started. :raise aiohttp.ClientConnectionResetError: if the connection is closing. .. versionchanged:: 3.0 The method is converted into :term:`coroutine` .. method:: pong(message=b'') :async: Send *unsolicited* :const:`~aiohttp.WSMsgType.PONG` to peer. :param message: optional payload of *pong* message, :class:`str` (converted to *UTF-8* encoded bytes) or :class:`bytes`. :raise RuntimeError: if the connection is not started. :raise aiohttp.ClientConnectionResetError: if the connection is closing. .. versionchanged:: 3.0 The method is converted into :term:`coroutine` .. method:: send_str(data, compress=None) :async: Send *data* to peer as :const:`~aiohttp.WSMsgType.TEXT` message. :param str data: data to send. :param int compress: sets specific level of compression for single message, ``None`` for not overriding per-socket setting. :raise RuntimeError: if the connection is not started. :raise TypeError: if data is not :class:`str` :raise aiohttp.ClientConnectionResetError: if the connection is closing. .. versionchanged:: 3.0 The method is converted into :term:`coroutine`, *compress* parameter added. ..
method:: send_bytes(data, compress=None) :async: Send *data* to peer as :const:`~aiohttp.WSMsgType.BINARY` message. :param data: data to send. :param int compress: sets specific level of compression for single message, ``None`` for not overriding per-socket setting. :raise RuntimeError: if the connection is not started. :raise TypeError: if data is not :class:`bytes`, :class:`bytearray` or :class:`memoryview`. :raise aiohttp.ClientConnectionResetError: if the connection is closing. .. versionchanged:: 3.0 The method is converted into :term:`coroutine`, *compress* parameter added. .. method:: send_json(data, compress=None, *, dumps=json.dumps) :async: Send *data* to peer as JSON string. :param data: data to send. :param int compress: sets specific level of compression for single message, ``None`` for not overriding per-socket setting. :param collections.abc.Callable dumps: any :term:`callable` that accepts an object and returns a JSON string (:func:`json.dumps` by default). :raise RuntimeError: if the connection is not started. :raise ValueError: if data is not serializable object :raise TypeError: if value returned by ``dumps`` param is not :class:`str` :raise aiohttp.ClientConnectionResetError: if the connection is closing. .. versionchanged:: 3.0 The method is converted into :term:`coroutine`, *compress* parameter added. .. method:: send_json_bytes(data, compress=None, *, dumps) :async: Send *data* to peer as a JSON binary frame using a bytes-returning encoder. :param data: data to send. :param int compress: sets specific level of compression for single message, ``None`` for not overriding per-socket setting. :param collections.abc.Callable dumps: any :term:`callable` that accepts an object and returns JSON as :class:`bytes` (e.g. ``orjson.dumps``). :raise RuntimeError: if the connection is not started. :raise ValueError: if data is not serializable object :raise TypeError: if value returned by ``dumps`` param is not :class:`bytes` .. 
method:: send_frame(message, opcode, compress=None) :async: Send a :const:`~aiohttp.WSMsgType` message *message* to peer. This method is low-level and should be used with caution as it only accepts bytes which must conform to the correct message type for *message*. It is recommended to use the :meth:`send_str`, :meth:`send_bytes` or :meth:`send_json` methods instead of this method. The primary use case for this method is to send bytes that have already been encoded without having to decode and re-encode them. :param bytes message: message to send. :param ~aiohttp.WSMsgType opcode: opcode of the message. :param int compress: sets specific level of compression for single message, ``None`` for not overriding per-socket setting. :raise RuntimeError: if the connection is not started. :raise aiohttp.ClientConnectionResetError: if the connection is closing. .. versionadded:: 3.11 .. method:: close(*, code=WSCloseCode.OK, message=b'', drain=True) :async: A :ref:`coroutine` that initiates closing handshake by sending :const:`~aiohttp.WSMsgType.CLOSE` message. It is safe to call `close()` from a different task. :param int code: closing code. See also :class:`~aiohttp.WSCloseCode`. :param message: optional payload of *close* message, :class:`str` (converted to *UTF-8* encoded bytes) or :class:`bytes`. :param bool drain: drain outgoing buffer before closing connection. :raise RuntimeError: if connection is not started .. method:: receive(timeout=None) :async: A :ref:`coroutine` that waits for an upcoming *data* message from peer and returns it. The coroutine implicitly handles :const:`~aiohttp.WSMsgType.PING`, :const:`~aiohttp.WSMsgType.PONG` and :const:`~aiohttp.WSMsgType.CLOSE` without returning the message. It processes the *ping-pong game* and performs the *closing handshake* internally. .. note:: Can only be called by the request handling task. :param timeout: timeout for `receive` operation. The timeout value overrides the response's receive_timeout attribute.
:return: :class:`~aiohttp.WSMessage` :raise RuntimeError: if connection is not started :raise asyncio.TimeoutError: if timeout expires before receiving a message .. method:: receive_str(*, timeout=None) :async: A :ref:`coroutine` that calls :meth:`receive` but also asserts the message type is :const:`~aiohttp.WSMsgType.TEXT`. .. note:: Can only be called by the request handling task. :param timeout: timeout for `receive` operation. The timeout value overrides the response's receive_timeout attribute. :return str: peer's message content. :raise aiohttp.WSMessageTypeError: if message is not :const:`~aiohttp.WSMsgType.TEXT`. :raise asyncio.TimeoutError: if timeout expires before receiving a message .. method:: receive_bytes(*, timeout=None) :async: A :ref:`coroutine` that calls :meth:`receive` but also asserts the message type is :const:`~aiohttp.WSMsgType.BINARY`. .. note:: Can only be called by the request handling task. :param timeout: timeout for `receive` operation. The timeout value overrides the response's receive_timeout attribute. :return bytes: peer's message content. :raise aiohttp.WSMessageTypeError: if message is not :const:`~aiohttp.WSMsgType.BINARY`. :raise asyncio.TimeoutError: if timeout expires before receiving a message .. method:: receive_json(*, loads=json.loads, timeout=None) :async: A :ref:`coroutine` that calls :meth:`receive_str` and loads the JSON string to a Python dict. .. note:: Can only be called by the request handling task. :param collections.abc.Callable loads: any :term:`callable` that accepts :class:`str` and returns :class:`dict` with parsed JSON (:func:`json.loads` by default). :param timeout: timeout for `receive` operation. The timeout value overrides the response's receive_timeout attribute. :return dict: loaded JSON content :raise TypeError: if message is :const:`~aiohttp.WSMsgType.BINARY`. :raise ValueError: if message is not valid JSON. :raise asyncio.TimeoutError: if timeout expires before receiving a message ..
seealso:: :ref:`WebSockets handling` .. class:: WebSocketReady :canonical: aiohttp.web_ws.WebSocketReady A named tuple for returning result from :meth:`WebSocketResponse.can_prepare`. Has :class:`bool` check implemented, e.g.:: if not ws.can_prepare(...): cannot_start_websocket() .. attribute:: ok ``True`` if websocket connection can be established, ``False`` otherwise. .. attribute:: protocol :class:`str` representing the selected websocket sub-protocol. .. seealso:: :meth:`WebSocketResponse.can_prepare` .. function:: json_response([data], *, text=None, body=None, \ status=200, reason=None, headers=None, \ content_type='application/json', \ dumps=json.dumps) :canonical: aiohttp.web_response.json_response Return :class:`Response` with predefined ``'application/json'`` content type and *data* encoded by ``dumps`` parameter (:func:`json.dumps` by default). .. function:: json_bytes_response([data], *, dumps, body=None, \ status=200, reason=None, headers=None, \ content_type='application/json') Return :class:`Response` with predefined ``'application/json'`` content type and *data* encoded by ``dumps`` parameter which must return :class:`bytes` directly (e.g. ``orjson.dumps``). Use this when your JSON encoder returns :class:`bytes` instead of :class:`str`, avoiding the :class:`str`-to-:class:`bytes` encoding overhead. .. class:: ResponseKey(name, t) :canonical: aiohttp.helpers.ResponseKey Keys for use in :class:`Response`. See :class:`AppKey` for more details. .. _aiohttp-web-app-and-router: Application and Router ---------------------- .. class:: Application(*, logger=<default>, middlewares=(), \ handler_args=None, client_max_size=1024**2, \ debug=...) :canonical: aiohttp.web_app.Application Application is a synonym for web-server. To get a fully working example, you have to make an *application*, register supported urls in the *router* and pass it to :func:`aiohttp.web.run_app` or :class:`aiohttp.web.AppRunner`.
*Application* contains a *router* instance and a list of callbacks that will be called during application finishing. This class is a :obj:`dict`-like object, so you can use it for :ref:`sharing data` globally by storing arbitrary properties for later access from a :ref:`handler` via the :attr:`Request.app` property:: app = Application() database = AppKey("database", AsyncEngine) app[database] = await create_async_engine(db_url) async def handler(request): async with request.app[database].begin() as conn: await conn.execute("DELETE * FROM table") Although it is a :obj:`dict`-like object, it can't be duplicated like one using :meth:`~aiohttp.web.Application.copy`. The class inherits :class:`dict`. :param logger: :class:`logging.Logger` instance for storing application logs. By default the value is ``logging.getLogger("aiohttp.web")`` :param middlewares: :class:`list` of middleware factories, see :ref:`aiohttp-web-middlewares` for details. :param handler_args: dict-like object that overrides keyword arguments of :class:`AppRunner` constructor. :param client_max_size: client's maximum size in a request, in bytes. If a POST request exceeds this value, it raises an `HTTPRequestEntityTooLarge` exception. :param debug: Switches debug mode. .. deprecated:: 3.5 The argument does nothing starting from 4.0, use asyncio :ref:`asyncio-debug-mode` instead. .. attribute:: router Read-only property that returns *router instance*. .. attribute:: logger :class:`logging.Logger` instance for storing application logs. .. attribute:: debug Boolean value indicating whether the debug mode is turned on or off. .. deprecated:: 3.5 Use asyncio :ref:`asyncio-debug-mode` instead. .. attribute:: on_response_prepare A :class:`~aiosignal.Signal` that is fired near the end of :meth:`StreamResponse.prepare` with parameters *request* and *response*.
It can be used, for example, to add custom headers to each response, or to modify the default headers computed by the application, directly before sending the headers to the client. Signal handlers should have the following signature:: async def on_prepare(request, response): pass .. note:: The headers are written immediately after these callbacks are run. Therefore, if you modify the content of the response, you may need to adjust the `Content-Length` header or similar to match. Aiohttp will not make any updates to the headers from this point. .. attribute:: on_startup A :class:`~aiosignal.Signal` that is fired on application start-up. Subscribers may use the signal to run background tasks in the event loop along with the application's request handler just after the application start-up. Signal handlers should have the following signature:: async def on_startup(app): pass .. seealso:: :ref:`aiohttp-web-signals` and :ref:`aiohttp-web-cleanup-ctx`. .. attribute:: on_shutdown A :class:`~aiosignal.Signal` that is fired on application shutdown. Subscribers may use the signal for gracefully closing long running connections, e.g. websockets and data streaming. Signal handlers should have the following signature:: async def on_shutdown(app): pass It's up to end user to figure out which :term:`web-handler`\s are still alive and how to finish them properly. We suggest keeping a list of long running handlers in :class:`Application` dictionary. .. seealso:: :ref:`aiohttp-web-graceful-shutdown` and :attr:`on_cleanup`. .. attribute:: on_cleanup A :class:`~aiosignal.Signal` that is fired on application cleanup. Subscribers may use the signal for gracefully closing connections to database server etc. Signal handlers should have the following signature:: async def on_cleanup(app): pass .. seealso:: :ref:`aiohttp-web-signals` and :attr:`on_shutdown`. .. attribute:: cleanup_ctx A list of *context generators* for *startup*/*cleanup* handling. 
Signal handlers should have the following signature:: @contextlib.asynccontextmanager async def context(app: web.Application) -> AsyncIterator[None]: # do startup stuff yield # do cleanup .. versionadded:: 3.1 .. seealso:: :ref:`aiohttp-web-cleanup-ctx`. .. method:: add_subapp(prefix, subapp) Register nested sub-application under given path *prefix*. In resolving process if request's path starts with *prefix* then further resolving is passed to *subapp*. :param str prefix: path's prefix for the resource. :param Application subapp: nested application attached under *prefix*. :returns: a :class:`PrefixedSubAppResource` instance. .. method:: add_domain(domain, subapp) Register nested sub-application that serves the domain name or domain name mask. In resolving process if request.headers['host'] matches the pattern *domain* then further resolving is passed to *subapp*. .. warning:: Registering many domains using this method may cause performance issues with handler routing. If you have a substantial number of applications for different domains, you may want to consider using a reverse proxy (such as Nginx) to handle routing to different apps, rather than registering them as sub-applications. :param str domain: domain or mask of domain for the resource. :param Application subapp: nested application. :returns: a :class:`~aiohttp.web.MatchedSubAppResource` instance. .. method:: add_routes(routes_table) Register route definitions from *routes_table*. The table is a :class:`list` of :class:`RouteDef` items or :class:`RouteTableDef`. :returns: :class:`list` of registered :class:`AbstractRoute` instances. The method is a shortcut for ``app.router.add_routes(routes_table)``, see also :meth:`UrlDispatcher.add_routes`. .. versionadded:: 3.1 .. versionchanged:: 3.7 Return value updated from ``None`` to :class:`list` of :class:`AbstractRoute` instances. .. method:: startup() :async: A :ref:`coroutine` that will be called along with the application's request handler.
The purpose of the method is calling :attr:`on_startup` signal handlers. .. method:: shutdown() :async: A :ref:`coroutine` that should be called on server stopping but before :meth:`cleanup`. The purpose of the method is calling :attr:`on_shutdown` signal handlers. .. method:: cleanup() :async: A :ref:`coroutine` that should be called on server stopping but after :meth:`shutdown`. The purpose of the method is calling :attr:`on_cleanup` signal handlers. .. note:: Application object has :attr:`router` attribute but has no ``add_route()`` method. The reason is: we want to support different router implementations (even maybe not url-matching based but traversal ones). For sake of that fact we have very trivial ABC for :class:`~aiohttp.abc.AbstractRouter`: it should have only :meth:`aiohttp.abc.AbstractRouter.resolve` coroutine. No methods for adding routes or route reversing (getting URL by route name). All those are router implementation details (but, sure, you need to deal with that methods after choosing the router for your application). .. class:: AppKey(name, t) :canonical: aiohttp.helpers.AppKey This class should be used for the keys in :class:`Application`. They provide a type-safe alternative to `str` keys when checking your code with a type checker (e.g. mypy). They also avoid name clashes with keys from different libraries etc. :param name: A name to help with debugging. This should be the same as the variable name (much like how :class:`typing.TypeVar` is used). :param t: The type that should be used for the value in the dict (e.g. `str`, `Iterator[int]` etc.) .. class:: Server :canonical: aiohttp.web_server.Server A protocol factory compatible with :meth:`~asyncio.AbstractEventLoop.create_server`. The class is responsible for creating HTTP protocol objects that can handle HTTP connections. .. attribute:: connections List of all currently opened connections. .. attribute:: requests_count Amount of processed requests. .. 
method:: Server.shutdown(timeout) :async: A :ref:`coroutine` that should be called to close all opened connections. .. class:: UrlDispatcher() :canonical: aiohttp.web_urldispatcher.UrlDispatcher For dispatching URLs to :ref:`handlers` :mod:`aiohttp.web` uses *routers*, which is any object that implements :class:`~aiohttp.abc.AbstractRouter` interface. This class is a straightforward url-matching router, implementing :class:`collections.abc.Mapping` for access to *named routes*. :class:`Application` uses this class as :meth:`~aiohttp.web.Application.router` by default. Before running an :class:`Application` you should fill *route table* first by calling :meth:`add_route` and :meth:`add_static`. :ref:`Handler` lookup is performed by iterating on added *routes* in FIFO order. The first matching *route* will be used to call the corresponding *handler*. If during route creation you specify *name* parameter the result is a *named route*. A *named route* can be retrieved by a ``app.router[name]`` call, checking for existence can be done with ``name in app.router`` etc. .. seealso:: :ref:`Route classes ` .. method:: add_resource(path, *, name=None) Append a :term:`resource` to the end of route table. *path* may be either *constant* string like ``'/a/b/c'`` or *variable rule* like ``'/a/{var}'`` (see :ref:`handling variable paths `) :param str path: resource path spec. :param str name: optional resource name. :return: created resource instance (:class:`PlainResource` or :class:`DynamicResource`). .. method:: add_route(method, path, handler, *, \ name=None, expect_handler=None) Append :ref:`handler` to the end of route table. *path* may be either *constant* string like ``'/a/b/c'`` or *variable rule* like ``'/a/{var}'`` (see :ref:`handling variable paths `) Pay attention please: *handler* must be a coroutine. :param str method: HTTP method for route. 
Should be one of ``'GET'``, ``'POST'``, ``'PUT'``, ``'DELETE'``, ``'PATCH'``, ``'HEAD'``, ``'OPTIONS'`` or ``'*'`` for any method. The parameter is case-insensitive, e.g. you can push ``'get'`` as well as ``'GET'``. :param str path: route path. Should be started with slash (``'/'``). :param collections.abc.Callable handler: route handler. :param str name: optional route name. :param collections.abc.Coroutine expect_handler: optional *expect* header handler. :returns: new :class:`AbstractRoute` instance. .. method:: add_routes(routes_table) Register route definitions from *routes_table*. The table is a :class:`list` of :class:`RouteDef` items or :class:`RouteTableDef`. :returns: :class:`list` of registered :class:`AbstractRoute` instances. .. versionadded:: 2.3 .. versionchanged:: 3.7 Return value updated from ``None`` to :class:`list` of :class:`AbstractRoute` instances. .. method:: add_get(path, handler, *, name=None, allow_head=True, **kwargs) Shortcut for adding a GET handler. Calls the :meth:`add_route` with \ ``method`` equals to ``'GET'``. If *allow_head* is ``True`` (default) the route for method HEAD is added with the same handler as for GET. If *name* is provided the name for HEAD route is suffixed with ``'-head'``. For example ``router.add_get(path, handler, name='route')`` call adds two routes: first for GET with name ``'route'`` and second for HEAD with name ``'route-head'``. .. method:: add_post(path, handler, **kwargs) Shortcut for adding a POST handler. Calls the :meth:`add_route` with \ ``method`` equals to ``'POST'``. .. method:: add_head(path, handler, **kwargs) Shortcut for adding a HEAD handler. Calls the :meth:`add_route` with \ ``method`` equals to ``'HEAD'``. .. method:: add_put(path, handler, **kwargs) Shortcut for adding a PUT handler. Calls the :meth:`add_route` with \ ``method`` equals to ``'PUT'``. .. method:: add_patch(path, handler, **kwargs) Shortcut for adding a PATCH handler. 
Calls the :meth:`add_route` with \ ``method`` equals to ``'PATCH'``. .. method:: add_delete(path, handler, **kwargs) Shortcut for adding a DELETE handler. Calls the :meth:`add_route` with \ ``method`` equals to ``'DELETE'``. .. method:: add_view(path, handler, **kwargs) Shortcut for adding a class-based view handler. Calls the \ :meth:`add_route` with ``method`` equals to ``'*'``. .. versionadded:: 3.0 .. method:: add_static(prefix, path, *, name=None, expect_handler=None, \ chunk_size=256*1024, \ response_factory=StreamResponse, \ show_index=False, \ follow_symlinks=False, \ append_version=False) Adds a router and a handler for returning static files. Useful for serving static content like images, javascript and css files. On platforms that support it, the handler will transfer files more efficiently using the ``sendfile`` system call. In some situations it might be necessary to avoid using the ``sendfile`` system call even if the platform supports it. This can be accomplished by setting the environment variable ``AIOHTTP_NOSENDFILE=1``. If a Brotli or gzip compressed version of the static content exists at the requested path with the ``.br`` or ``.gz`` extension, it will be used for the response. Brotli will be preferred over gzip if both files exist. .. warning:: Use :meth:`add_static` for development only. In production, static content should be processed by web servers like *nginx* or *apache*. Such web servers will be able to provide significantly better performance and security for static assets. Several past security vulnerabilities in aiohttp only affected applications using :meth:`add_static`. :param str prefix: URL path prefix for handled static files :param path: path to the folder in file system that contains handled static files, :class:`str` or :class:`pathlib.Path`. :param str name: optional route name. :param collections.abc.Coroutine expect_handler: optional *expect* header handler.
:param int chunk_size: size of single chunk for file downloading, 256 KB by default. Increasing *chunk_size* parameter to, say, 1 MB may increase file downloading speed but consumes more memory. :param bool show_index: flag for allowing to show indexes of a directory, by default it's not allowed and HTTP/403 will be returned on directory access. :param bool follow_symlinks: flag for allowing to follow symlinks that lead outside the static root directory, by default it's not allowed and HTTP/404 will be returned on access. Enabling ``follow_symlinks`` can be a security risk, and may lead to a directory traversal attack. You do NOT need this option to follow symlinks which point to somewhere else within the static directory, this option is only used to break out of the security sandbox. Enabling this option is highly discouraged, and only expected to be used for edge cases in a local development setting where remote users do not have access to the server. :param bool append_version: flag for adding file version (hash) to the url query string, this value will be used as default when you call to :meth:`~aiohttp.web.AbstractRoute.url` and :meth:`~aiohttp.web.AbstractRoute.url_for` methods. :returns: new :class:`~aiohttp.web.AbstractRoute` instance. .. method:: resolve(request) :async: A :ref:`coroutine` that returns :class:`~aiohttp.abc.AbstractMatchInfo` for *request*. The method never raises exception, but returns :class:`~aiohttp.abc.AbstractMatchInfo` instance with: 1. :attr:`~aiohttp.abc.AbstractMatchInfo.http_exception` assigned to :exc:`HTTPException` instance. 2. :meth:`~aiohttp.abc.AbstractMatchInfo.handler` which raises :exc:`HTTPNotFound` or :exc:`HTTPMethodNotAllowed` on handler's execution if there is no registered route for *request*. *Middlewares* can process those exceptions to render pretty-looking error page for example. Used by internal machinery, end users are unlikely to need to call the method. ..
note:: The method uses :attr:`aiohttp.web.BaseRequest.raw_path` for pattern matching against registered routes. .. method:: resources() The method returns a *view* for *all* registered resources. The view is an object that allows to: 1. Get size of the router table:: len(app.router.resources()) 2. Iterate over registered resources:: for resource in app.router.resources(): print(resource) 3. Make a check if the resources is registered in the router table:: route in app.router.resources() .. method:: routes() The method returns a *view* for *all* registered routes. .. method:: named_resources() Returns a :obj:`dict`-like :class:`types.MappingProxyType` *view* over *all* named **resources**. The view maps every named resource's **name** to the :class:`AbstractResource` instance. It supports the usual :obj:`dict`-like operations, except for any mutable operations (i.e. it's **read-only**):: len(app.router.named_resources()) for name, resource in app.router.named_resources().items(): print(name, resource) "name" in app.router.named_resources() app.router.named_resources()["name"] .. _aiohttp-web-resource: Resource ^^^^^^^^ Default router :class:`UrlDispatcher` operates with :term:`resource`\s. Resource is an item in *routing table* which has a *path*, an optional unique *name* and at least one :term:`route`. :term:`web-handler` lookup is performed in the following way: 1. The router splits the URL and checks the index from longest to shortest. For example, '/one/two/three' will first check the index for '/one/two/three', then '/one/two' and finally '/'. 2. If the URL part is found in the index, the list of routes for that URL part is iterated over. If a route matches to requested HTTP method (or ``'*'`` wildcard) the route's handler is used as the chosen :term:`web-handler`. The lookup is finished. 3. 
If the route is not found in the index, the router tries to find the route in the list of :class:`~aiohttp.web.MatchedSubAppResource`, (currently only created from :meth:`~aiohttp.web.Application.add_domain`), and will iterate over the list of :class:`~aiohttp.web.MatchedSubAppResource` in a linear fashion until a match is found. 4. If no *resource* / *route* pair was found, the *router* returns the special :class:`~aiohttp.abc.AbstractMatchInfo` instance whose :attr:`aiohttp.abc.AbstractMatchInfo.http_exception` is not ``None`` but an :exc:`HTTPException` with either *HTTP 404 Not Found* or *HTTP 405 Method Not Allowed* status code. Registered :meth:`~aiohttp.abc.AbstractMatchInfo.handler` raises this exception on call. Fixed paths are preferred over variable paths. For example, if you have two routes ``/a/b`` and ``/a/{name}``, then the first route will always be preferred over the second one. If there are multiple dynamic paths with the same fixed prefix, they will be resolved in order of registration. For example, if you have two dynamic routes that are prefixed with the fixed ``/users`` path such as ``/users/{x}/{y}/z`` and ``/users/{x}/y/z``, the first one will be preferred over the second one. Users should never instantiate resource classes directly but obtain them via a :meth:`UrlDispatcher.add_resource` call. After that they may add a :term:`route` by calling :meth:`Resource.add_route`. :meth:`UrlDispatcher.add_route` is just shortcut for:: router.add_resource(path).add_route(method, handler) Resource with a *name* is called *named resource*. The main purpose of *named resource* is constructing URL by route name for passing it into *template engine* for example:: url = app.router['resource_name'].url_for().with_query({'a': 1, 'b': 2}) Resource classes hierarchy:: AbstractResource Resource PlainResource DynamicResource PrefixResource StaticResource PrefixedSubAppResource MatchedSubAppResource .. 
class:: AbstractResource :canonical: aiohttp.web_urldispatcher.AbstractResource A base class for all resources. Inherited from :class:`collections.abc.Sized` and :class:`collections.abc.Iterable`. ``len(resource)`` returns the number of :term:`route`\s belonging to the resource, ``for route in resource`` allows to iterate over these routes. .. attribute:: name Read-only *name* of resource or ``None``. .. attribute:: canonical Read-only *canonical path* associated with the resource. For example ``/path/to`` or ``/path/{to}`` .. versionadded:: 3.3 .. method:: resolve(request) :async: Resolve resource by finding appropriate :term:`web-handler` for ``(method, path)`` combination. :return: (*match_info*, *allowed_methods*) pair. *allowed_methods* is a :class:`set` of HTTP methods accepted by resource. *match_info* is either :class:`UrlMappingMatchInfo` if request is resolved or ``None`` if no :term:`route` is found. .. method:: get_info() A resource description, e.g. ``{'path': '/path/to'}`` or ``{'formatter': '/path/{to}', 'pattern': re.compile(r'^/path/(?P<to>[a-zA-Z][_a-zA-Z0-9]+)$')}`` .. method:: url_for(*args, **kwargs) Construct a URL for route with additional params. *args* and **kwargs** depend on a parameters list accepted by inherited resource class. :return: :class:`~yarl.URL` -- resulting URL instance. .. class:: Resource :canonical: aiohttp.web_urldispatcher.Resource A base class for new-style resources, inherits :class:`AbstractResource`. .. method:: add_route(method, handler, *, expect_handler=None) Add a :term:`web-handler` to resource. :param str method: HTTP method for route. Should be one of ``'GET'``, ``'POST'``, ``'PUT'``, ``'DELETE'``, ``'PATCH'``, ``'HEAD'``, ``'OPTIONS'`` or ``'*'`` for any method. The parameter is case-insensitive, e.g. you can push ``'get'`` as well as ``'GET'``. The method should be unique for resource. :param collections.abc.Callable handler: route handler. 
:param collections.abc.Coroutine expect_handler: optional *expect* header handler. :returns: new :class:`ResourceRoute` instance. .. class:: PlainResource :canonical: aiohttp.web_urldispatcher.PlainResource A resource, inherited from :class:`Resource`. The class corresponds to resources with plain-text matching, ``'/path/to'`` for example. .. attribute:: canonical Read-only *canonical path* associated with the resource. Returns the path used to create the PlainResource. For example ``/path/to`` .. versionadded:: 3.3 .. method:: url_for() Returns a :class:`~yarl.URL` for the resource. .. class:: DynamicResource :canonical: aiohttp.web_urldispatcher.DynamicResource A resource, inherited from :class:`Resource`. The class corresponds to resources with :ref:`variable <aiohttp-web-variable-handler>` matching, e.g. ``'/path/{to}/{param}'`` etc. .. attribute:: canonical Read-only *canonical path* associated with the resource. Returns the formatter obtained from the path used to create the DynamicResource. For example, from a path ``/get/{num:^\d+}``, it returns ``/get/{num}`` .. versionadded:: 3.3 .. method:: url_for(**params) Returns a :class:`~yarl.URL` for the resource. :param params: -- a variable substitutions for dynamic resource. E.g. for ``'/path/{to}/{param}'`` pattern the method should be called as ``resource.url_for(to='val1', param='val2')`` .. class:: StaticResource :canonical: aiohttp.web_urldispatcher.StaticResource A resource, inherited from :class:`Resource`. The class corresponds to resources for :ref:`static file serving <aiohttp-web-static-file-handling>`. .. attribute:: canonical Read-only *canonical path* associated with the resource. Returns the prefix used to create the StaticResource. For example ``/prefix`` .. versionadded:: 3.3 .. method:: url_for(filename, append_version=None) Returns a :class:`~yarl.URL` for file path under resource prefix. :param filename: -- a file name substitution for static file handler. Accepts both :class:`str` and :class:`pathlib.Path`. E.g. 
a URL for ``'/prefix/dir/file.txt'`` should be generated as ``resource.url_for(filename='dir/file.txt')`` :param bool append_version: -- a flag for adding file version (hash) to the url query string for cache busting. By default it takes its value from the constructor (``False`` by default). When set to ``True``, a ``v=FILE_HASH`` query string param will be added. When set to ``False`` it has no impact. If the file is not found, the flag has no impact. .. class:: PrefixedSubAppResource :canonical: aiohttp.web_urldispatcher.PrefixedSubAppResource A resource for serving nested applications. The class instance is returned by :meth:`~aiohttp.web.Application.add_subapp` call. .. attribute:: canonical Read-only *canonical path* associated with the resource. Returns the prefix used to create the PrefixedSubAppResource. For example ``/prefix`` .. versionadded:: 3.3 .. method:: url_for(**kwargs) The call is not allowed, it raises :exc:`RuntimeError`. .. _aiohttp-web-route: Route ^^^^^ Route has *HTTP method* (wildcard ``'*'`` is an option), :term:`web-handler` and optional *expect handler*. Every route belongs to some resource. Route classes hierarchy:: AbstractRoute ResourceRoute SystemRoute :class:`ResourceRoute` is the route used for resources, :class:`SystemRoute` serves URL resolving errors like *404 Not Found* and *405 Method Not Allowed*. .. class:: AbstractRoute :canonical: aiohttp.web_urldispatcher.AbstractRoute Base class for routes served by :class:`UrlDispatcher`. .. attribute:: method HTTP method handled by the route, e.g. *GET*, *POST* etc. .. attribute:: handler :ref:`handler` that processes the route. .. attribute:: name Name of the route, always equals to name of resource which owns the route. .. attribute:: resource Resource instance which holds the route, ``None`` for :class:`SystemRoute`. .. method:: url_for(*args, **kwargs) Abstract method for constructing url handled by the route. Actually it's a shortcut for ``route.resource.url_for(...)``. .. 
method:: handle_expect_header(request) :async: ``100-continue`` handler. .. class:: ResourceRoute :canonical: aiohttp.web_urldispatcher.ResourceRoute The route class for handling different HTTP methods for :class:`Resource`. .. class:: SystemRoute :canonical: aiohttp.web_urldispatcher.SystemRoute The route class for handling URL resolution errors like *404 Not Found* and *405 Method Not Allowed*. .. attribute:: status HTTP status code .. attribute:: reason HTTP status reason .. _aiohttp-web-route-def: RouteDef and StaticDef ^^^^^^^^^^^^^^^^^^^^^^ Route definition, a description of a not-yet-registered route. Could be used for filling route table by providing a list of route definitions (Django style). The definition is created by functions like :func:`get` or :func:`post`, list of definitions could be added to router by :meth:`UrlDispatcher.add_routes` call:: from aiohttp import web async def handle_get(request): ... async def handle_post(request): ... app.router.add_routes([web.get('/get', handle_get), web.post('/post', handle_post)]) .. class:: AbstractRouteDef :canonical: aiohttp.web_routedef.AbstractRouteDef A base class for route definitions. Inherited from :class:`abc.ABC`. .. versionadded:: 3.1 .. method:: register(router) Register itself into :class:`UrlDispatcher`. Abstract method, should be overridden by subclasses. :returns: :class:`list` of registered :class:`AbstractRoute` objects. .. versionchanged:: 3.7 Return value updated from ``None`` to :class:`list` of :class:`AbstractRoute` instances. .. class:: RouteDef :canonical: aiohttp.web_routedef.RouteDef A definition of a not-yet-registered route. Implements :class:`AbstractRouteDef`. .. versionadded:: 2.3 .. versionchanged:: 3.1 The class implements :class:`AbstractRouteDef` interface. .. attribute:: method HTTP method (``GET``, ``POST`` etc.) (:class:`str`). .. attribute:: path Path to resource, e.g. ``/path/to``. Could contain ``{}`` brackets for :ref:`variable resources <aiohttp-web-variable-handler>` (:class:`str`). .. 
attribute:: handler An async function to handle HTTP request. .. attribute:: kwargs A :class:`dict` of additional arguments. .. class:: StaticDef :canonical: aiohttp.web_routedef.StaticDef A definition of static file resource. Implements :class:`AbstractRouteDef`. .. versionadded:: 3.1 .. attribute:: prefix A prefix used for static file handling, e.g. ``/static``. .. attribute:: path File system directory to serve, :class:`str` or :class:`pathlib.Path` (e.g. ``'/home/web-service/path/to/static'``). .. attribute:: kwargs A :class:`dict` of additional arguments, see :meth:`UrlDispatcher.add_static` for a list of supported options. .. function:: get(path, handler, *, name=None, allow_head=True, \ expect_handler=None) :canonical: aiohttp.web_routedef.get Return :class:`RouteDef` for processing ``GET`` requests. See :meth:`UrlDispatcher.add_get` for information about parameters. .. versionadded:: 2.3 .. function:: post(path, handler, *, name=None, expect_handler=None) :canonical: aiohttp.web_routedef.post Return :class:`RouteDef` for processing ``POST`` requests. See :meth:`UrlDispatcher.add_post` for information about parameters. .. versionadded:: 2.3 .. function:: head(path, handler, *, name=None, expect_handler=None) :canonical: aiohttp.web_routedef.head Return :class:`RouteDef` for processing ``HEAD`` requests. See :meth:`UrlDispatcher.add_head` for information about parameters. .. versionadded:: 2.3 .. function:: put(path, handler, *, name=None, expect_handler=None) :canonical: aiohttp.web_routedef.put Return :class:`RouteDef` for processing ``PUT`` requests. See :meth:`UrlDispatcher.add_put` for information about parameters. .. versionadded:: 2.3 .. function:: patch(path, handler, *, name=None, expect_handler=None) :canonical: aiohttp.web_routedef.patch Return :class:`RouteDef` for processing ``PATCH`` requests. See :meth:`UrlDispatcher.add_patch` for information about parameters. .. versionadded:: 2.3 .. 
function:: delete(path, handler, *, name=None, expect_handler=None) :canonical: aiohttp.web_routedef.delete Return :class:`RouteDef` for processing ``DELETE`` requests. See :meth:`UrlDispatcher.add_delete` for information about parameters. .. versionadded:: 2.3 .. function:: view(path, handler, *, name=None, expect_handler=None) :canonical: aiohttp.web_routedef.view Return :class:`RouteDef` for processing ``ANY`` requests. See :meth:`UrlDispatcher.add_view` for information about parameters. .. versionadded:: 3.0 .. function:: static(prefix, path, *, name=None, expect_handler=None, \ chunk_size=256*1024, \ show_index=False, follow_symlinks=False, \ append_version=False) :canonical: aiohttp.web_routedef.static Return :class:`StaticDef` for processing static files. See :meth:`UrlDispatcher.add_static` for information about supported parameters. .. versionadded:: 3.1 .. function:: route(method, path, handler, *, name=None, expect_handler=None) :canonical: aiohttp.web_routedef.route Return :class:`RouteDef` for processing requests that decided by ``method``. See :meth:`UrlDispatcher.add_route` for information about parameters. .. versionadded:: 2.3 .. _aiohttp-web-route-table-def: RouteTableDef ^^^^^^^^^^^^^ A routes table definition used for describing routes by decorators (Flask style):: from aiohttp import web routes = web.RouteTableDef() @routes.get('/get') async def handle_get(request): ... @routes.post('/post') async def handle_post(request): ... app.router.add_routes(routes) @routes.view("/view") class MyView(web.View): async def get(self): ... async def post(self): ... .. class:: RouteTableDef() :canonical: aiohttp.web_routedef.RouteTableDef A sequence of :class:`RouteDef` instances (implements :class:`collections.abc.Sequence` protocol). In addition to all standard :class:`list` methods the class provides also methods like ``get()`` and ``post()`` for adding new route definition. .. versionadded:: 2.3 .. 
decoratormethod:: get(path, *, allow_head=True, \ name=None, expect_handler=None) Add a new :class:`RouteDef` item for registering ``GET`` web-handler. See :meth:`UrlDispatcher.add_get` for information about parameters. .. decoratormethod:: post(path, *, name=None, expect_handler=None) Add a new :class:`RouteDef` item for registering ``POST`` web-handler. See :meth:`UrlDispatcher.add_post` for information about parameters. .. decoratormethod:: head(path, *, name=None, expect_handler=None) Add a new :class:`RouteDef` item for registering ``HEAD`` web-handler. See :meth:`UrlDispatcher.add_head` for information about parameters. .. decoratormethod:: put(path, *, name=None, expect_handler=None) Add a new :class:`RouteDef` item for registering ``PUT`` web-handler. See :meth:`UrlDispatcher.add_put` for information about parameters. .. decoratormethod:: patch(path, *, name=None, expect_handler=None) Add a new :class:`RouteDef` item for registering ``PATCH`` web-handler. See :meth:`UrlDispatcher.add_patch` for information about parameters. .. decoratormethod:: delete(path, *, name=None, expect_handler=None) Add a new :class:`RouteDef` item for registering ``DELETE`` web-handler. See :meth:`UrlDispatcher.add_delete` for information about parameters. .. decoratormethod:: view(path, *, name=None, expect_handler=None) Add a new :class:`RouteDef` item for registering ``ANY`` methods against a class-based view. See :meth:`UrlDispatcher.add_view` for information about parameters. .. versionadded:: 3.0 .. method:: static(prefix, path, *, name=None, expect_handler=None, \ chunk_size=256*1024, \ show_index=False, follow_symlinks=False, \ append_version=False) Add a new :class:`StaticDef` item for registering static files processor. See :meth:`UrlDispatcher.add_static` for information about supported parameters. .. versionadded:: 3.1 .. 
decoratormethod:: route(method, path, *, name=None, expect_handler=None) Add a new :class:`RouteDef` item for registering a web-handler for arbitrary HTTP method. See :meth:`UrlDispatcher.add_route` for information about parameters. MatchInfo ^^^^^^^^^ After route matching web application calls found handler if any. Matching result can be accessible from handler as :attr:`Request.match_info` attribute. In general the result may be any object derived from :class:`~aiohttp.abc.AbstractMatchInfo` (:class:`UrlMappingMatchInfo` for default :class:`UrlDispatcher` router). .. class:: UrlMappingMatchInfo :canonical: aiohttp.web_urldispatcher.UrlMappingMatchInfo Inherited from :class:`dict` and :class:`~aiohttp.abc.AbstractMatchInfo`. Dict items are filled by matching info and are :term:`resource`\-specific. .. attribute:: expect_handler A coroutine for handling ``100-continue``. .. attribute:: handler A coroutine for handling request. .. attribute:: route :class:`AbstractRoute` instance for url matching. View ^^^^ .. class:: View(request) :canonical: aiohttp.web_urldispatcher.View Inherited from :class:`~aiohttp.abc.AbstractView`. Base class for class based views. Implementations should derive from :class:`View` and override methods for handling HTTP verbs like ``get()`` or ``post()``:: class MyView(View): async def get(self): resp = await get_response(self.request) return resp async def post(self): resp = await post_response(self.request) return resp app.router.add_view('/view', MyView) The view raises *405 Method Not Allowed* (:class:`HTTPMethodNotAllowed`) if the requested HTTP verb is not supported. :param request: instance of :class:`Request` that has initiated a view processing. .. attribute:: request Request sent to view's constructor, read-only property. Overridable coroutine methods: ``connect()``, ``delete()``, ``get()``, ``head()``, ``options()``, ``patch()``, ``post()``, ``put()``, ``trace()``. .. seealso:: :ref:`aiohttp-web-class-based-views` .. 
_aiohttp-web-app-runners-reference: Running Applications -------------------- To start web application there is ``AppRunner`` and site classes. Runner is a storage for running application, sites are for running application on specific TCP or Unix socket, e.g.:: runner = web.AppRunner(app) await runner.setup() site = web.TCPSite(runner, 'localhost', 8080) await site.start() # wait for finish signal await runner.cleanup() .. versionadded:: 3.0 :class:`AppRunner` / :class:`ServerRunner` and :class:`TCPSite` / :class:`UnixSite` / :class:`SockSite` are added in aiohttp 3.0 .. class:: BaseRunner :canonical: aiohttp.web_runner.BaseRunner A base class for runners. Use :class:`AppRunner` for serving :class:`Application`, :class:`ServerRunner` for low-level :class:`Server`. .. attribute:: server Low-level web :class:`Server` for handling HTTP requests, read-only attribute. .. attribute:: addresses A :class:`list` of served socket addresses. See :meth:`socket.getsockname() <socket.socket.getsockname>` for items type. .. versionadded:: 3.3 .. attribute:: sites A read-only :class:`set` of served sites (:class:`TCPSite` / :class:`UnixSite` / :class:`NamedPipeSite` / :class:`SockSite` instances). .. method:: setup() :async: Initialize the server. Should be called before adding sites. .. method:: cleanup() :async: Stop handling all registered sites and cleanup used resources. .. class:: AppRunner(app, *, handle_signals=False, **kwargs) :canonical: aiohttp.web_runner.AppRunner A runner for :class:`Application`. Used in conjunction with sites to serve on specific port. Inherited from :class:`BaseRunner`. :param Application app: web application instance to serve. :param bool handle_signals: add signal handlers for :data:`signal.SIGINT` and :data:`signal.SIGTERM` (``False`` by default). These handlers will raise :exc:`GracefulExit`. :param kwargs: named parameters to pass into web protocol. Supported *kwargs*: :param bool tcp_keepalive: Enable TCP Keep-Alive. Default: ``True``. 
:param int keepalive_timeout: Number of seconds before closing Keep-Alive connection. Default: ``3630`` seconds (when deployed behind a reverse proxy it's important for this value to be higher than the proxy's timeout. To avoid race conditions we always want the proxy to close the connection). :param logger: Custom logger object. Default: :data:`aiohttp.log.server_logger`. :param access_log: Custom logging object. Default: :data:`aiohttp.log.access_logger`. :param access_log_class: Class for `access_logger`. Default: :data:`aiohttp.helpers.AccessLogger`. Must be a subclass of :class:`aiohttp.abc.AbstractAccessLogger`. :param str access_log_format: Access log format string. Default: :attr:`helpers.AccessLogger.LOG_FORMAT`. :param int max_line_size: Optional maximum header line size. Default: ``8190``. :param int max_field_size: Optional maximum header combined name and value size. Default: ``8190``. :param int max_headers: Optional maximum number of headers and trailers combined. Default: ``128``. :param float lingering_time: Maximum time during which the server reads and ignores additional data coming from the client when lingering close is on. Use ``0`` to disable lingering on server channel closing. :param int read_bufsize: Size of the read buffer (:attr:`BaseRequest.content`). ``None`` by default, it means that the session global value is used. .. versionadded:: 3.7 :param bool auto_decompress: Automatically decompress request body, ``True`` by default. .. versionadded:: 3.8 .. attribute:: app Read-only attribute for accessing to :class:`Application` served instance. .. method:: setup() :async: Initialize application. Should be called before adding sites. The method calls :attr:`Application.on_startup` registered signals. .. method:: cleanup() :async: Stop handling all registered sites and cleanup used resources. :attr:`Application.on_shutdown` and :attr:`Application.on_cleanup` signals are called internally. .. 
class:: ServerRunner(web_server, *, handle_signals=False, **kwargs) :canonical: aiohttp.web_runner.ServerRunner A runner for low-level :class:`Server`. Used in conjunction with sites to serve on specific port. Inherited from :class:`BaseRunner`. :param Server web_server: low-level web server instance to serve. :param bool handle_signals: add signal handlers for :data:`signal.SIGINT` and :data:`signal.SIGTERM` (``False`` by default). These handlers will raise :exc:`GracefulExit`. :param kwargs: named parameters to pass into web protocol. .. seealso:: :ref:`aiohttp-web-lowlevel` demonstrates low-level server usage .. class:: BaseSite :canonical: aiohttp.web_runner.BaseSite An abstract class for handled sites. .. attribute:: name An identifier for site, read-only :class:`str` property. Could be a handled URL or UNIX socket path. .. method:: start() :async: Start handling a site. .. method:: stop() :async: Stop handling a site. .. class:: TCPSite(runner, host=None, port=None, *, \ shutdown_timeout=60.0, ssl_context=None, \ backlog=128, reuse_address=None, \ reuse_port=None) :canonical: aiohttp.web_runner.TCPSite Serve a runner on TCP socket. :param runner: a runner to serve. :param str host: HOST to listen on, all interfaces if ``None`` (default). :param int port: PORT to listen on, ``8080`` if ``None`` (default). Use ``0`` to let the OS assign a free ephemeral port (see :attr:`port`). :param float shutdown_timeout: a timeout used for both waiting on pending tasks before application shutdown and for closing opened connections on :meth:`BaseSite.stop` call. :param ssl_context: a :class:`ssl.SSLContext` instance for serving SSL/TLS secure server, ``None`` for plain HTTP server (default). :param int backlog: a number of unaccepted connections that the system will allow before refusing new connections, see :meth:`socket.socket.listen` for details. ``128`` by default. 
:param bool reuse_address: tells the kernel to reuse a local socket in TIME_WAIT state, without waiting for its natural timeout to expire. If not specified will automatically be set to True on UNIX. :param bool reuse_port: tells the kernel to allow this endpoint to be bound to the same port as other existing endpoints are bound to, so long as they all set this flag when being created. This option is not supported on Windows. .. attribute:: port Read-only. The actual port number the server is bound to, only guaranteed to be correct after the site has been started. .. class:: UnixSite(runner, path, *, \ shutdown_timeout=60.0, ssl_context=None, \ backlog=128) :canonical: aiohttp.web_runner.UnixSite Serve a runner on UNIX socket. :param runner: a runner to serve. :param str path: PATH to UNIX socket to listen. :param float shutdown_timeout: a timeout used for both waiting on pending tasks before application shutdown and for closing opened connections on :meth:`BaseSite.stop` call. :param ssl_context: a :class:`ssl.SSLContext` instance for serving SSL/TLS secure server, ``None`` for plain HTTP server (default). :param int backlog: a number of unaccepted connections that the system will allow before refusing new connections, see :meth:`socket.socket.listen` for details. ``128`` by default. .. class:: NamedPipeSite(runner, path, *, shutdown_timeout=60.0) :canonical: aiohttp.web_runner.NamedPipeSite Serve a runner on Named Pipe in Windows. :param runner: a runner to serve. :param str path: PATH of named pipe to listen. :param float shutdown_timeout: a timeout used for both waiting on pending tasks before application shutdown and for closing opened connections on :meth:`BaseSite.stop` call. .. class:: SockSite(runner, sock, *, \ shutdown_timeout=60.0, ssl_context=None, \ backlog=128) :canonical: aiohttp.web_runner.SockSite Serve a runner on a preexisting socket. :param runner: a runner to serve. :param sock: A :ref:`socket instance <socket-objects>` to listen to. 
:param float shutdown_timeout: a timeout used for both waiting on pending tasks before application shutdown and for closing opened connections on :meth:`BaseSite.stop` call. :param ssl_context: a :class:`ssl.SSLContext` instance for serving SSL/TLS secure server, ``None`` for plain HTTP server (default). :param int backlog: a number of unaccepted connections that the system will allow before refusing new connections, see :meth:`socket.socket.listen` for details. ``128`` by default. .. exception:: GracefulExit :canonical: aiohttp.web_runner.GracefulExit Raised by signal handlers for :data:`signal.SIGINT` and :data:`signal.SIGTERM` defined in :class:`AppRunner` and :class:`ServerRunner` when ``handle_signals`` is set to ``True``. Inherited from :exc:`SystemExit`, which exits with error code ``1`` if not handled. Utilities --------- .. class:: FileField :canonical: aiohttp.web_request.FileField A :mod:`dataclass <dataclasses>` instance that is returned as multidict value by :meth:`aiohttp.web.BaseRequest.post` if the field is an uploaded file. .. attribute:: name Field name .. attribute:: filename File name as specified by uploading (client) side. .. attribute:: file An :class:`io.IOBase` instance with content of uploaded file. .. attribute:: content_type *MIME type* of uploaded file, ``'text/plain'`` by default. .. seealso:: :ref:`aiohttp-web-file-upload` .. function:: run_app(app, *, debug=False, host=None, port=None, \ path=None, sock=None, shutdown_timeout=60.0, \ keepalive_timeout=3630, ssl_context=None, \ print=print, backlog=128, \ access_log_class=aiohttp.helpers.AccessLogger, \ access_log_format=aiohttp.helpers.AccessLogger.LOG_FORMAT, \ access_log=aiohttp.log.access_logger, \ handle_signals=True, \ reuse_address=None, \ reuse_port=None, \ handler_cancellation=False, \ **kwargs) A high-level function for running an application, serving it until keyboard interrupt and performing a :ref:`aiohttp-web-graceful-shutdown`. 
This is a high-level function very similar to :func:`asyncio.run` and should be used as the main entry point for an application. The :class:`Application` object essentially becomes our `main()` function. If additional tasks need to be run in parallel, see :ref:`aiohttp-web-complex-applications`. The server will listen on any host or Unix domain socket path you supply. If no hosts or paths are supplied, or only a port is supplied, a TCP server listening on 0.0.0.0 (all hosts) will be launched. Distributing HTTP traffic to multiple hosts or paths on the same application process provides no performance benefit as the requests are handled on the same event loop. See :doc:`deployment` for ways of distributing work for increased performance. :param app: :class:`Application` instance to run or a *coroutine* that returns an application. :param bool debug: enable :ref:`asyncio debug mode <asyncio-debug-mode>` if ``True``. :param str host: TCP/IP host or a sequence of hosts for HTTP server. Default is ``'0.0.0.0'`` if *port* has been specified or if *path* is not supplied. :param int port: TCP/IP port for HTTP server. Default is ``8080`` for plain text HTTP and ``8443`` for HTTP via SSL (when *ssl_context* parameter is specified). :param path: file system path for HTTP server Unix domain socket. A sequence of file system paths can be used to bind multiple domain sockets. Listening on Unix domain sockets is not supported by all operating systems, :class:`str`, :class:`pathlib.Path` or an iterable of these. :param socket.socket sock: a preexisting socket object to accept connections on. A sequence of socket objects can be passed. :param int shutdown_timeout: a delay to wait for graceful server shutdown before disconnecting all open client sockets the hard way. This is used as a delay to wait for pending tasks to complete and then again to close any pending connections. 
A system with properly :ref:`aiohttp-web-graceful-shutdown` implemented never waits for the second timeout but closes a server in a few milliseconds. :param float keepalive_timeout: a delay before a TCP connection is closed after an HTTP request. The delay allows for reuse of a TCP connection. When deployed behind a reverse proxy it's important for this value to be higher than the proxy's timeout. To avoid race conditions, we always want the proxy to handle connection closing. .. versionadded:: 3.8 :param ssl_context: :class:`ssl.SSLContext` for HTTPS server, ``None`` for HTTP connection. :param print: a callable compatible with :func:`print`. May be used to override STDOUT output or suppress it. Passing `None` disables output. :param int backlog: the number of unaccepted connections that the system will allow before refusing new connections (``128`` by default). :param access_log_class: class for `access_logger`. Default: :data:`aiohttp.helpers.AccessLogger`. Must be a subclass of :class:`aiohttp.abc.AbstractAccessLogger`. :param access_log: :class:`logging.Logger` instance used for saving access logs. Use ``None`` for disabling logs for sake of speedup. :param access_log_format: access log format, see :ref:`aiohttp-logging-access-log-format-spec` for details. :param bool handle_signals: override signal TERM handling to gracefully exit the application. :param bool reuse_address: tells the kernel to reuse a local socket in TIME_WAIT state, without waiting for its natural timeout to expire. If not specified will automatically be set to True on UNIX. :param bool reuse_port: tells the kernel to allow this endpoint to be bound to the same port as other existing endpoints are bound to, so long as they all set this flag when being created. This option is not supported on Windows. :param bool handler_cancellation: cancels the web handler task if the client drops the connection. This is recommended if familiar with asyncio behavior or scalability is a concern. 
:ref:`aiohttp-web-peer-disconnection` :param kwargs: additional named parameters to pass into :class:`AppRunner` constructor. .. versionadded:: 3.0 Support *access_log_class* parameter. Support *reuse_address*, *reuse_port* parameter. .. versionadded:: 3.1 Accept a coroutine as *app* parameter. .. versionadded:: 3.9 Support handler_cancellation parameter (this was the default behavior in aiohttp <3.7). Constants --------- .. class:: ContentCoding :canonical: aiohttp.web_response.ContentCoding An :class:`enum.Enum` class of available Content Codings. .. attribute:: deflate *DEFLATE compression* .. attribute:: gzip *GZIP compression* .. attribute:: identity *no compression* Middlewares ----------- .. function:: normalize_path_middleware(*, \ append_slash=True, \ remove_slash=False, \ merge_slashes=True, \ redirect_class=HTTPPermanentRedirect) :canonical: aiohttp.web_middlewares.normalize_path_middleware Middleware factory which produces a middleware that normalizes the path of a request. By normalizing it means: - Add or remove a trailing slash to the path. - Double slashes are replaced by one. The middleware returns as soon as it finds a path that resolves correctly. The order if both merge and append/remove are enabled is: 1. *merge_slashes* 2. *append_slash* or *remove_slash* 3. both *merge_slashes* and *append_slash* or *remove_slash* If the path resolves with at least one of those conditions, it will redirect to the new path. Only one of *append_slash* and *remove_slash* can be enabled. If both are ``True`` the factory will raise an ``AssertionError`` If *append_slash* is ``True`` the middleware will append a slash when needed. If a resource is defined with trailing slash and the request comes without it, it will append it automatically. If *remove_slash* is ``True``, *append_slash* must be ``False``. When enabled the middleware will remove trailing slashes and redirect if the resource is defined. 
If *merge_slashes* is ``True``, merge multiple consecutive slashes in the path into one. .. versionadded:: 3.4 Support for *remove_slash* ================================================ FILE: docs/websocket_utilities.rst ================================================ .. currentmodule:: aiohttp WebSocket utilities =================== .. class:: WSCloseCode :canonical: aiohttp._websocket.models.WSCloseCode An :class:`~enum.IntEnum` for keeping close message code. .. attribute:: OK A normal closure, meaning that the purpose for which the connection was established has been fulfilled. .. attribute:: GOING_AWAY An endpoint is "going away", such as a server going down or a browser having navigated away from a page. .. attribute:: PROTOCOL_ERROR An endpoint is terminating the connection due to a protocol error. .. attribute:: UNSUPPORTED_DATA An endpoint is terminating the connection because it has received a type of data it cannot accept (e.g., an endpoint that understands only text data MAY send this if it receives a binary message). .. attribute:: INVALID_TEXT An endpoint is terminating the connection because it has received data within a message that was not consistent with the type of the message (e.g., non-UTF-8 :rfc:`3629` data within a text message). .. attribute:: POLICY_VIOLATION An endpoint is terminating the connection because it has received a message that violates its policy. This is a generic status code that can be returned when there is no other more suitable status code (e.g., :attr:`~aiohttp.WSCloseCode.UNSUPPORTED_DATA` or :attr:`~aiohttp.WSCloseCode.MESSAGE_TOO_BIG`) or if there is a need to hide specific details about the policy. .. attribute:: MESSAGE_TOO_BIG An endpoint is terminating the connection because it has received a message that is too big for it to process. .. 
attribute:: MANDATORY_EXTENSION An endpoint (client) is terminating the connection because it has expected the server to negotiate one or more extensions, but the server did not return them in the response message of the WebSocket handshake. The list of extensions that are needed should appear in the /reason/ part of the Close frame. Note that this status code is not used by the server, because it can fail the WebSocket handshake instead. .. attribute:: INTERNAL_ERROR A server is terminating the connection because it encountered an unexpected condition that prevented it from fulfilling the request. .. attribute:: SERVICE_RESTART The service is restarted. A client may reconnect, and if it chooses to do so, it should reconnect using a randomized delay of 5-30s. .. attribute:: TRY_AGAIN_LATER The service is experiencing overload. A client should only connect to a different IP (when there are multiple for the target) or reconnect to the same IP upon user action. .. attribute:: ABNORMAL_CLOSURE Used to indicate that a connection was closed abnormally (that is, with no close frame being sent) when a status code is expected. .. attribute:: BAD_GATEWAY The server was acting as a gateway or proxy and received an invalid response from the upstream server. This is similar to 502 HTTP Status Code. .. class:: WSMsgType :canonical: aiohttp._websocket.models.WSMsgType An :class:`~enum.IntEnum` for describing :class:`WSMessage` type. .. attribute:: CONTINUATION A marker for a continuation frame; the user will never get a message with this type. .. attribute:: TEXT Text message, the value has :class:`str` type. .. attribute:: BINARY Binary message, the value has :class:`bytes` type. .. attribute:: PING Ping frame (sent by client peer). .. attribute:: PONG Pong frame, answer on ping. Sent by server peer. .. attribute:: CLOSE Close frame. .. attribute:: CLOSED Actually not a frame but a flag indicating that the websocket was closed. ..
attribute:: ERROR Actually not a frame but a flag indicating that the websocket has received an error. .. class:: WSMessage Websocket message, returned by ``.receive()`` calls. This is actually defined as a :class:`typing.Union` of different message types. All messages are a :func:`collections.namedtuple` with the below attributes. .. attribute:: data Message payload. 1. :class:`str` for :attr:`WSMsgType.TEXT` messages. 2. :class:`bytes` for :attr:`WSMsgType.BINARY` messages. 3. :class:`int` (see :class:`WSCloseCode` for common codes) for :attr:`WSMsgType.CLOSE` messages. 4. :class:`bytes` for :attr:`WSMsgType.PING` messages. 5. :class:`bytes` for :attr:`WSMsgType.PONG` messages. 6. :class:`Exception` for :attr:`WSMsgType.ERROR` messages. .. attribute:: extra Additional info, :class:`str` if provided, otherwise defaults to ``None``. Makes sense only for :attr:`WSMsgType.CLOSE` messages, contains optional message description. .. attribute:: type Message type, :class:`WSMsgType` instance. .. method:: json(*, loads=json.loads) Returns parsed JSON data (the method is only present on :attr:`WSMsgType.TEXT` and :attr:`WSMsgType.BINARY` messages). :param loads: optional JSON decoder function. ================================================ FILE: docs/whats_new_1_1.rst ================================================ ========================= What's new in aiohttp 1.1 ========================= YARL and URL encoding ====================== Since aiohttp 1.1 the library uses :term:`yarl` for URL processing. New API ------- :class:`yarl.URL` gives handy methods for URL operations etc. Client API still accepts :class:`str` everywhere *url* is used, e.g. ``session.get('http://example.com')`` works as well as ``session.get(yarl.URL('http://example.com'))``. Internal API has been switched to :class:`yarl.URL`. :class:`aiohttp.CookieJar` accepts :class:`~yarl.URL` instances only.
The server side has gained the :attr:`aiohttp.web.BaseRequest.url` and :attr:`aiohttp.web.BaseRequest.rel_url` properties, which represent the absolute and relative request URL respectively. Using URL objects is the recommended way; the pre-existing properties for retrieving URL parts are deprecated and will eventually be removed. Redirection web exceptions accept :class:`yarl.URL` as the *location* parameter. :class:`str` is still supported and will be supported forever. Reverse URL processing for *router* has been changed. The main API is ``aiohttp.web.Request.url_for`` which returns a :class:`yarl.URL` instance for a named resource. It does not support *query args* but adding *args* is trivial: ``request.url_for('named_resource', param='a').with_query(arg='val')``. The method returns a *relative* URL; an absolute URL may be constructed by a ``request.url.join(request.url_for(...))`` call. URL encoding ------------ YARL encodes all non-ASCII symbols on :class:`yarl.URL` creation. Thus ``URL('https://www.python.org/путь')`` becomes ``'https://www.python.org/%D0%BF%D1%83%D1%82%D1%8C'``. On filling route table it's possible to use both non-ASCII and percent encoded paths:: app.router.add_get('/путь', handler) and:: app.router.add_get('/%D0%BF%D1%83%D1%82%D1%8C', handler) are the same. Internally ``'/путь'`` is converted into percent-encoding representation. Route matching also accepts both URL forms: raw and encoded by converting the route pattern to *canonical* (encoded) form on route registration. Sub-Applications ================ Sub applications are designed for solving the problem of the big monolithic code base. Let's assume we have a project with own business logic and tools like administration panel and debug toolbar. Administration panel is a separate application by its own nature but all toolbar URLs are served by prefix like ``/admin``.
Thus we'll create a totally separate application named ``admin`` and connect it to main app with prefix:: admin = web.Application() # setup admin routes, signals and middlewares app.add_subapp('/admin/', admin) Middlewares and signals from ``app`` and ``admin`` are chained. It means that if URL is ``'/admin/something'`` middlewares from ``app`` are applied first and ``admin.middlewares`` are the next in the call chain. The same goes for the :attr:`~aiohttp.web.Application.on_response_prepare` signal -- the signal is delivered to both top level ``app`` and ``admin`` if processing URL is routed to ``admin`` sub-application. Common signals like :attr:`~aiohttp.web.Application.on_startup`, :attr:`~aiohttp.web.Application.on_shutdown` and :attr:`~aiohttp.web.Application.on_cleanup` are delivered to all registered sub-applications. The passed parameter is the sub-application instance, not the top-level application. Third level sub-applications can be nested into second level ones -- there is no limit on the nesting level. Url reversing ------------- Url reversing for sub-applications should generate urls with proper prefix. To get the URL, the sub-application's router should be used:: admin = web.Application() admin.router.add_get('/resource', handler, name='name') app.add_subapp('/admin/', admin) url = admin.router['name'].url_for() The generated ``url`` from example will have a value ``URL('/admin/resource')``. Application freezing ==================== Application can be used either as main app (``app.make_handler()``) or as sub-application -- not both cases at the same time. After connecting application by ``.add_subapp()`` call or starting serving web-server as toplevel application the application is **frozen**. It means that registering new routes, signals and middlewares is forbidden. Changing state (``app['name'] = 'value'``) of frozen application is deprecated and will be eventually removed.
================================================ FILE: docs/whats_new_3_0.rst ================================================ .. _aiohttp_whats_new_3_0: ========================= What's new in aiohttp 3.0 ========================= async/await everywhere ====================== The main change is dropping ``yield from`` support and using ``async``/``await`` everywhere. Farewell, Python 3.4. The minimal supported Python version is **3.5.3** now. Why not *3.5.0*? Because *3.5.3* has a crucial change: :func:`asyncio.get_event_loop()` returns the running loop instead of *default*, which may be different, e.g.:: loop = asyncio.new_event_loop() loop.run_until_complete(f()) Note, :func:`asyncio.set_event_loop` was not called and default loop is not equal to actually executed one. Application Runners =================== People constantly asked about the ability to run aiohttp servers together with other asyncio code, but :func:`aiohttp.web.run_app` is a blocking synchronous call. aiohttp had support for starting the application without ``run_app`` but the API was very low-level and cumbersome. Now application runners solve the task in a few lines of code, see :ref:`aiohttp-web-app-runners` for details. Client Tracing ============== Another long-awaited feature is tracing the client request life cycle to figure out when and why a client request spends time waiting for connection establishment, getting server response headers etc. Now it is possible by registering special signal handlers on every request processing stage. :ref:`aiohttp-client-tracing` provides more info about the feature. HTTPS support ============= Unfortunately asyncio has a bug with checking SSL certificates for non-ASCII site DNS names, e.g. `https://историк.рф <https://историк.рф>`_ or `https://雜草工作室.香港 <https://雜草工作室.香港>`_. The bug has been fixed in upcoming Python 3.7 only (the change requires breaking backward compatibility in :mod:`ssl` API). aiohttp installs a fix for older Python versions (3.5 and 3.6).
async def on_shutdown(app: web.Application) -> None:
    """Close every tracked WebSocket when the application shuts down.

    Registered on ``app.on_shutdown``. Each socket stored under the
    ``websockets`` key is closed with a "going away" status.

    :param app: the application whose ``websockets`` list is drained.
    """
    # Iterate over a snapshot: while we are suspended in ``ws.close()`` the
    # ``finally`` clause of ``websocket_handler`` may remove the socket from
    # this very list, and mutating a list during iteration skips elements.
    for ws in list(app[websockets]):
        # 1001 is the RFC 6455 "going away" status (server shutting down).
        # The previous value, 999, lies outside the valid close-code range.
        await ws.close(code=1001, message=b"Server shutdown")
class BasicAuthMiddleware:
    """Client middleware that attaches an HTTP Basic ``Authorization`` header.

    The header value is computed once in the constructor and reused for
    every request that passes through the middleware.
    """

    def __init__(self, username: str, password: str) -> None:
        self.username = username
        self.password = password
        self._auth_header = self._encode_credentials()

    def _encode_credentials(self) -> str:
        """Return the ``Basic <base64(username:password)>`` header value."""
        raw = "{}:{}".format(self.username, self.password).encode()
        token = base64.b64encode(raw).decode()
        return "Basic " + token

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Set the ``Authorization`` header unless the request already has one."""
        headers = request.headers
        if hdrs.AUTHORIZATION in headers:
            # An explicit header on the request wins over the middleware's.
            pass
        else:
            headers[hdrs.AUTHORIZATION] = self._auth_header

        # Hand the (possibly amended) request to the next handler in the chain.
        return await handler(request)
async def main() -> None:
    """Start the local test server, run the demo requests, then shut it down."""
    runner = await run_test_server()
    try:
        await run_tests()
    finally:
        # Clean up the runner even if one of the demo requests raised.
        await runner.cleanup()
# ``asyncio.run()`` creates a fresh event loop, runs ``go()`` to completion and
# closes the loop afterwards. The previous ``asyncio.get_event_loop()`` call
# relied on a loop being created implicitly, which is deprecated and raises
# ``RuntimeError`` on Python 3.14+; it also never closed the loop.
asyncio.run(go())
async def start_client(url: str) -> None:
    """Run an interactive console WebSocket client against *url*.

    Prompts for a user name, then forwards every line read from stdin to the
    server prefixed with that name, while a background task prints whatever
    the server sends. Reading stops at end-of-file on stdin (Ctrl+D).
    """
    name = input("Please enter your name: ")
    msg_types = aiohttp.WSMsgType

    async def dispatch(ws: aiohttp.ClientWebSocketResponse) -> None:
        """Print incoming messages until a close, closed or error message arrives."""
        while True:
            msg = await ws.receive()
            kind = msg.type

            if kind is msg_types.TEXT:
                print("Text: ", msg.data.strip())
                continue
            if kind is msg_types.BINARY:
                print("Binary: ", msg.data)
                continue
            if kind is msg_types.PING:
                # autoping is disabled on the connection, so answer manually.
                await ws.pong()
                continue
            if kind is msg_types.PONG:
                print("Pong received")
                continue

            # Every other message type ends the receive loop.
            if kind is msg_types.CLOSE:
                await ws.close()
            elif kind is msg_types.ERROR:
                print("Error during receive %s" % ws.exception())
            return

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(url, autoclose=False, autoping=False) as ws:
            # Receive in the background while stdin is read in the foreground.
            dispatch_task = asyncio.create_task(dispatch(ws))

            # Exit with Ctrl+D
            while True:
                line = await asyncio.to_thread(sys.stdin.readline)
                if not line:
                    break
                await ws.send_str(name + ": " + line)

            dispatch_task.cancel()
            with suppress(asyncio.CancelledError):
                await dispatch_task
class LoggingMiddleware:
    """Client middleware that logs each request and how long its response took."""

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Log the outgoing request, await *handler*, then log status and duration."""
        began = time.monotonic()
        _LOGGER.info("[REQUEST] %s %s", request.method, request.url)

        result = await handler(request)

        # Wall-clock time spent awaiting the handler, from a monotonic clock.
        elapsed = time.monotonic() - began
        _LOGGER.info(
            "[RESPONSE] %s in %.2fs - Status: %s",
            request.url.path,
            elapsed,
            result.status,
        )
        return result
class RetryMiddleware:
    """Middleware that retries failed requests with exponential backoff.

    A response is retried when its status is in ``retry_statuses``. The wait
    before the first retry is ``initial_delay`` seconds and is multiplied by
    ``backoff_factor`` after every retry.
    """

    def __init__(
        self,
        max_retries: int = 3,
        retry_statuses: set[HTTPStatus] | None = None,
        initial_delay: float = 1.0,
        backoff_factor: float = 2.0,
    ) -> None:
        """Configure the retry policy.

        :param max_retries: number of retries after the first attempt; must
            not be negative.
        :param retry_statuses: statuses that trigger a retry. ``None`` selects
            ``DEFAULT_RETRY_STATUSES``; an empty set disables retrying.
        :param initial_delay: seconds to wait before the first retry.
        :param backoff_factor: multiplier applied to the delay after each retry.
        :raises ValueError: if *max_retries* is negative.
        """
        if max_retries < 0:
            # A negative value would skip the request loop entirely and make
            # ``__call__`` return ``None`` instead of a response.
            raise ValueError("max_retries must be >= 0")
        self.max_retries = max_retries
        # Compare against ``None`` explicitly: ``retry_statuses or DEFAULT``
        # would silently replace a deliberately empty set with the defaults.
        self.retry_statuses = (
            DEFAULT_RETRY_STATUSES if retry_statuses is None else retry_statuses
        )
        self.initial_delay = initial_delay
        self.backoff_factor = backoff_factor

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Execute request with retry logic.

        Returns the first response whose status is not retryable, or the last
        response once ``max_retries`` retries have been used up.
        """
        last_response: ClientResponse | None = None
        delay = self.initial_delay

        for attempt in range(self.max_retries + 1):
            if attempt > 0:
                _LOGGER.info(
                    "Retrying request (attempt %s/%s)",
                    attempt + 1,
                    self.max_retries + 1,
                )

            # Execute the request
            response = await handler(request)
            last_response = response

            # Check if we should retry
            if response.status not in self.retry_statuses:
                return response

            # Don't retry if we've exhausted attempts
            if attempt >= self.max_retries:
                _LOGGER.warning("Max retries exceeded")
                return response

            # Wait before retrying
            _LOGGER.debug("Waiting %ss before retry...", delay)
            await asyncio.sleep(delay)
            delay *= self.backoff_factor

        if TYPE_CHECKING:
            assert last_response is not None  # Always set since we loop at least once
        return last_response
async def run_test_server() -> web.AppRunner:
    """Start the demo server on localhost:8080 and return its runner.

    The caller is responsible for cleaning up the returned runner.
    """
    app = web.Application()
    endpoints = TestServer()

    # Route table: (path, handler), registered in this order as GET routes.
    route_table = (
        ("/protected", endpoints.handle_protected),
        ("/flaky", endpoints.handle_flaky),
        ("/always-fail", endpoints.handle_always_fail),
        ("/status/{status}", endpoints.handle_status),
    )
    for path, handler in route_table:
        app.router.add_get(path, handler)

    runner = web.AppRunner(app)
    await runner.setup()
    await web.TCPSite(runner, "localhost", 8080).start()
    return runner
Response: {data}") else: print(f"Failed with status: {resp.status}") print("\n=== Test 2: Flaky endpoint (fails twice, then succeeds) ===") print("Watch the logs to see retries in action...") async with session.get("http://localhost:8080/flaky") as resp: if resp.status == 200: data = await resp.json() print(f"Success after retries! Response: {data}") else: text = await resp.text() print(f"Failed with status {resp.status}: {text}") print("\n=== Test 3: Always failing endpoint ===") async with session.get("http://localhost:8080/always-fail") as resp: print(f"Final status after retries: {resp.status}") print("\n=== Test 4: Non-retryable status (404) ===") async with session.get("http://localhost:8080/status/404") as resp: print(f"Status: {resp.status} (no retries for 404)") # Test without middleware for comparison print("\n=== Test 5: Request without middleware ===") print("Making a request to protected endpoint without middleware...") async with session.get( "http://localhost:8080/protected", middlewares=() ) as resp: print(f"Status without middleware: {resp.status}") if resp.status == 401: print("Failed as expected - no auth header added") async def main() -> None: # Start test server server = await run_test_server() try: await run_tests() finally: await server.cleanup() if __name__ == "__main__": asyncio.run(main()) ================================================ FILE: examples/curl.py ================================================ #!/usr/bin/env python3 import argparse import asyncio import sys import aiohttp async def curl(url: str) -> None: async with aiohttp.ClientSession() as session: async with session.request("GET", url) as response: print(repr(response)) chunk = await response.content.read() print("Downloaded: %s" % len(chunk)) if __name__ == "__main__": ARGS = argparse.ArgumentParser(description="GET url example") ARGS.add_argument("url", nargs=1, metavar="URL", help="URL to download") ARGS.add_argument( "--iocp", default=False, action="store_true", 
async def main() -> None:
    """Run a digest-auth request for every qop/algorithm combination.

    For each pair in ``TEST_COMBINATIONS`` a GET request is sent through a
    session carrying ``DigestAuthMiddleware``. The status and headers are
    always printed; the JSON body is decoded and reported only when the
    server answered 200, otherwise the raw body is shown and the attempt is
    reported as failed.
    """
    # Create a DigestAuthMiddleware instance with appropriate credentials
    digest_auth = DigestAuthMiddleware(login=USERNAME, password=PASSWORD)

    # Create a client session with the digest auth middleware
    async with ClientSession(middlewares=(digest_auth,)) as session:
        # Test each combination of QOP and algorithm
        for qop, algorithm in TEST_COMBINATIONS:
            print(f"\n\n=== Testing with qop={qop}, algorithm={algorithm} ===\n")

            url = URL(
                f"https://httpbin.org/digest-auth/{qop}/{USERNAME}/{PASSWORD}/{algorithm}"
            )

            async with session.get(url) as resp:
                print(f"Status: {resp.status}")
                print(f"Headers: {resp.headers}")

                if resp.status != 200:
                    # Decode JSON only on success: resp.json() raises when the
                    # body is not JSON, which would replace the failure report
                    # below with a traceback.
                    print(f"Response: {await resp.text()}")
                    print("\nAuthentication failed.")
                    continue

                # Parse the JSON response
                json_response = await resp.json()
                print(f"Response: {json_response}")

                print("\nAuthentication successful!")
                print(f"Authenticated user: {json_response.get('user')}")
                print(
                    f"Authentication method: {json_response.get('authenticated')}"
                )
async def main() -> None:
    """Query the fake Facebook server through a resolver that redirects DNS.

    The fake server is started on a local port, ``graph.facebook.com`` is
    mapped to that port via ``FakeResolver``, and two API calls are made and
    printed. The fake server is stopped even if a request fails.
    """
    token = "ER34gsSGGS34XCBKd7u"

    fake_facebook = FakeFacebook()
    info = await fake_facebook.start()
    try:
        resolver = FakeResolver(info)
        # ssl=False: the fake server presents the example certificate, so
        # certificate verification is switched off for this connector.
        connector = TCPConnector(resolver=resolver, ssl=False)

        async with ClientSession(connector=connector) as session:
            async with session.get(
                "https://graph.facebook.com/v2.7/me", params={"access_token": token}
            ) as resp:
                print(await resp.json())

            async with session.get(
                "https://graph.facebook.com/v2.7/me/friends",
                params={"access_token": token},
            ) as resp:
                print(await resp.json())
    finally:
        # Always shut the server down, otherwise a failed request leaves the
        # runner (and its listening socket) open.
        await fake_facebook.stop()
class LoggingMiddleware:
    """Client middleware that logs each request, its status and its duration."""

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Log *request*, pass it to *handler*, log the outcome and return it."""
        started = time.monotonic()
        self._log_request(request)

        response = await handler(request)

        elapsed = time.monotonic() - started
        self._log_response(request, response, elapsed)
        return response

    @staticmethod
    def _log_request(request: ClientRequest) -> None:
        """Log the request line, plus the headers when there are any."""
        _LOGGER.info("[REQUEST] %s %s", request.method, request.url)
        if request.headers:
            _LOGGER.debug("[REQUEST HEADERS] %s", request.headers)

    @staticmethod
    def _log_response(
        request: ClientRequest, response: ClientResponse, elapsed: float
    ) -> None:
        """Log status and elapsed seconds, then the response headers."""
        _LOGGER.info(
            "[RESPONSE] %s %s - Status: %s - Duration: %.3fs",
            request.method,
            request.url,
            response.status,
            elapsed,
        )
        _LOGGER.debug("[RESPONSE HEADERS] %s", response.headers)
async def run_test_server() -> web.AppRunner:
    """Bring up the logging demo server on localhost:8080.

    Registers the hello, slow, error and echo endpoints of ``TestServer``
    and returns the started runner.
    """
    application = web.Application()
    endpoints = TestServer()

    router = application.router
    add_get = router.add_get
    add_get("/hello", endpoints.handle_hello)
    add_get("/hello/{name}", endpoints.handle_hello)
    add_get("/slow/{delay}", endpoints.handle_slow)
    add_get("/error/{status}", endpoints.handle_error)
    # The echo endpoint is the only one that accepts a request body.
    router.add_post("/echo", endpoints.handle_json_data)

    app_runner = web.AppRunner(application)
    await app_runner.setup()

    tcp_site = web.TCPSite(app_runner, "localhost", 8080)
    await tcp_site.start()
    return app_runner
# Create the loop explicitly: asyncio.get_event_loop() is deprecated when no
# loop is running and no longer creates one implicitly on recent Python.
loop = asyncio.new_event_loop()

try:
    loop.run_until_complete(main(loop))
except KeyboardInterrupt:
    # Ctrl+C is the expected way to stop the example server.
    pass
finally:
    # Close the loop on every exit path, not only after a clean run.
    loop.close()
class RetryMiddleware:
    """Middleware that retries failed requests with exponential backoff.

    A response is retried when its status is in ``retry_statuses``. The wait
    before the first retry is ``initial_delay`` seconds and is multiplied by
    ``backoff_factor`` after every wait.
    """

    def __init__(
        self,
        max_retries: int = 3,
        retry_statuses: set[HTTPStatus] | None = None,
        initial_delay: float = 1.0,
        backoff_factor: float = 2.0,
    ) -> None:
        """Configure the retry policy.

        Args:
            max_retries: Retries allowed after the first attempt. Zero or a
                negative value means the request is sent exactly once.
            retry_statuses: Statuses that trigger a retry. ``None`` selects
                ``DEFAULT_RETRY_STATUSES``; an empty set disables retrying.
            initial_delay: Seconds to wait before the first retry.
            backoff_factor: Multiplier applied to the delay after each wait.
        """
        self.max_retries = max_retries
        # Compare against None explicitly: "or" would treat an intentionally
        # empty set as "not given" and silently re-enable the defaults.
        self.retry_statuses = (
            DEFAULT_RETRY_STATUSES if retry_statuses is None else retry_statuses
        )
        self.initial_delay = initial_delay
        self.backoff_factor = backoff_factor

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Execute the request, retrying while the status is retryable.

        Returns the first response whose status is not retryable, or the
        last response once the retries are exhausted. At least one attempt
        is always made, so a response is always returned.
        """
        delay = self.initial_delay
        attempt = 0

        while True:
            if attempt > 0:
                _LOGGER.info(
                    "Retrying request to %s (attempt %s/%s)",
                    request.url,
                    attempt + 1,
                    self.max_retries + 1,
                )

            # Execute the request
            response = await handler(request)

            # Check if we should retry
            if response.status not in self.retry_statuses:
                return response

            # Don't retry if we've exhausted attempts
            if attempt >= self.max_retries:
                _LOGGER.warning(
                    "Max retries (%s) exceeded for %s", self.max_retries, request.url
                )
                return response

            # Wait before retrying
            _LOGGER.debug("Waiting %ss before retry...", delay)
            await asyncio.sleep(delay)
            delay *= self.backoff_factor
            attempt += 1
dict[str, int] = {} self.status_sequences: dict[str, list[int]] = { "eventually-ok": [500, 503, 502, 200], # Fails 3 times, then succeeds "always-error": [500, 500, 500, 500], # Always fails "immediate-ok": [200], # Succeeds immediately "flaky": [503, 200], # Fails once, then succeeds } async def handle_status(self, request: web.Request) -> web.Response: """Return the status code specified in the path.""" status = int(request.match_info["status"]) return web.Response(status=status, text=f"Status: {status}") async def handle_status_sequence(self, request: web.Request) -> web.Response: """Return different status codes on sequential requests.""" path = request.path # Initialize counter for this path if needed if path not in self.request_counters: self.request_counters[path] = 0 # Get the status sequence for this path sequence_name = request.match_info["name"] if sequence_name not in self.status_sequences: return web.Response(status=404, text="Sequence not found") sequence = self.status_sequences[sequence_name] # Get the current status based on request count count = self.request_counters[path] if count < len(sequence): status = sequence[count] else: # After sequence ends, always return the last status status = sequence[-1] # Increment counter for next request self.request_counters[path] += 1 return web.Response( status=status, text=f"Request #{count + 1}: Status {status}" ) async def handle_delay(self, request: web.Request) -> web.Response: """Delay response by specified seconds.""" delay = float(request.match_info["delay"]) await asyncio.sleep(delay) return web.json_response({"delay": delay, "message": "Response after delay"}) async def handle_reset(self, request: web.Request) -> web.Response: """Reset request counters.""" self.request_counters = {} return web.Response(text="Counters reset") async def run_test_server() -> web.AppRunner: """Run a simple test server.""" app = web.Application() server = TestServer() app.router.add_get("/status/{status}", 
async def run_tests() -> None:
    """Run all retry middleware tests against the local test server.

    Each scenario prints the final status seen by the client after the
    middleware has applied its retries. The ``/reset`` calls clear the
    server-side request counters so the sequence endpoints start over.
    """
    # Create retry middleware with custom settings
    retry_middleware = RetryMiddleware(
        max_retries=3,
        retry_statuses=DEFAULT_RETRY_STATUSES,
        initial_delay=0.5,
        backoff_factor=2.0,
    )

    async with ClientSession(middlewares=(retry_middleware,)) as session:
        # Reset counters before tests. The context manager releases the
        # response; a bare ``await session.post(...)`` would leave that to
        # garbage collection.
        async with session.post("http://localhost:8080/reset"):
            pass

        # Test 1: Request that succeeds immediately
        print("=== Test 1: Immediate success ===")
        async with session.get("http://localhost:8080/sequence/immediate-ok") as resp:
            text = await resp.text()
            print(f"Final status: {resp.status}")
            print(f"Response: {text}")
            print("Success - no retries needed\n")

        # Test 2: Request that eventually succeeds after retries
        print("=== Test 2: Eventually succeeds (500->503->502->200) ===")
        async with session.get("http://localhost:8080/sequence/eventually-ok") as resp:
            text = await resp.text()
            print(f"Final status: {resp.status}")
            print(f"Response: {text}")
            if resp.status == 200:
                print("Success after retries!\n")
            else:
                print("Failed after retries\n")

        # Test 3: Request that always fails
        print("=== Test 3: Always fails (500->500->500->500) ===")
        async with session.get("http://localhost:8080/sequence/always-error") as resp:
            text = await resp.text()
            print(f"Final status: {resp.status}")
            print(f"Response: {text}")
            print("Failed after exhausting all retries\n")

        # Test 4: Flaky service (fails once then succeeds)
        print("=== Test 4: Flaky service (503->200) ===")
        async with session.post("http://localhost:8080/reset"):  # Reset counters
            pass
        async with session.get("http://localhost:8080/sequence/flaky") as resp:
            text = await resp.text()
            print(f"Final status: {resp.status}")
            print(f"Response: {text}")
            print("Success after one retry!\n")

        # Test 5: Non-retryable status
        print("=== Test 5: Non-retryable status (404) ===")
        async with session.get("http://localhost:8080/status/404") as resp:
            print(f"Final status: {resp.status}")
            print("Failed immediately - not a retryable status\n")

        # Test 6: Delayed response
        print("=== Test 6: Testing with delay endpoint ===")
        try:
            async with session.get("http://localhost:8080/delay/0.5") as resp:
                print(f"Status: {resp.status}")
                data = await resp.json()
                print(f"Response received after delay: {data}\n")
        except asyncio.TimeoutError:
            print("Request timed out\n")
/MFYLuJQ/fNoalDIC+ec0EIa9NbrfpoocJ8h6HlmWOqkES4QpBSOrkVid64Cdy3P JgiEWg== -----END CERTIFICATE----- ================================================ FILE: examples/server.csr ================================================ -----BEGIN CERTIFICATE REQUEST----- MIIChzCCAW8CAQAwQjELMAkGA1UEBhMCVUExEDAOBgNVBAgMB1VrcmFpbmUxITAf BgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEB BQADggEPADCCAQoCggEBAOUgkn3jX/sdg6GGueGDHCM+snIUVY3fM6D4jXjyBhnT 3TqKG1lJwCGYR11AD+2SJYppU+w4QaF6YZwMeZBKy+mVQ9+CrVYyKQE7j9H8XgNE HV9BQzoragT8lia8eC5aOQzUeX8AxCSSbsnyT/X+S1IKdd0txLOeZOD6pWwJoc3d pDELglk2b1tzhyN2GjQv3aRHj55Px7127MeZyRXwODFpXrpbnwih4OqkA4EYtmqF bZttGEzMhd4Y5mkbyuRbGM+IE99oQJMvnIkjAfUo0aKnDrcAIkWCkwLIci9TIG6u 3R1P2Tn+HYVntzQZ4BnxanbFNQ5S9ARd3529EmO3BzUCAwEAAaAAMA0GCSqGSIb3 DQEBCwUAA4IBAQDO/PSd29KgisTdGXhntg7yBEhBAjsDW7uQCrdrPSZtFyN6wUHy /1yrrWe56ZuW8jpuP5tG0eTZ+0bT2RXIRot8a2Cc3eBhpoe8M3d84yXjKAoHutGE 5IK+TViQdvT3pT3a7pTmjlf8Ojq9tx+U2ckiz8Ccnjd9yM47M9NgMhrS1aBpVZSt gOD+zzrqMML4xks9id94H7bi9Tgs3AbEJIyDpBpoK6i4OvK7KTidCngCg80qmdTy bcScLapoy1Ped2BKKuxWdOOlP+mDJatc/pcfBLE13AncQjJgMerS9M5RWCBjmRow A+aB6fBEU8bOTrqCryfBeTiV6xzyDDcIXtc6 -----END CERTIFICATE REQUEST----- ================================================ FILE: examples/server.key ================================================ -----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEA5SCSfeNf+x2DoYa54YMcIz6ychRVjd8zoPiNePIGGdPdOoob WUnAIZhHXUAP7ZIlimlT7DhBoXphnAx5kErL6ZVD34KtVjIpATuP0fxeA0QdX0FD OitqBPyWJrx4Llo5DNR5fwDEJJJuyfJP9f5LUgp13S3Es55k4PqlbAmhzd2kMQuC WTZvW3OHI3YaNC/dpEePnk/HvXbsx5nJFfA4MWleulufCKHg6qQDgRi2aoVtm20Y TMyF3hjmaRvK5FsYz4gT32hAky+ciSMB9SjRoqcOtwAiRYKTAshyL1Mgbq7dHU/Z Of4dhWe3NBngGfFqdsU1DlL0BF3fnb0SY7cHNQIDAQABAoIBAG9BJ6B03VADfrzZ vDwh+3Gpqd/2u6wNqvYIejk123yDATLBiJIMW3x0goJm7tT+V7gjeJqEnmmYEPlC nWxQxT6AOdq3iw8FgB+XGjhuAAA5/MEZ4VjHZ81QEGBytzBaosT2DqB6cMMJTz5D qEvb1Brb9WsWJCLLUFRloBkbfDOG9lMvt34ixYTTmqjsVj5WByD5BhzKH51OJ72L 00IYpvrsEOtSev1hNV4199CHPYE90T/YQVooRBiHtTcfN+/KNVJu6Rf/zcaJ3WMS 
async def wshandle(request: web.Request) -> web.StreamResponse:
    """Serve a WebSocket that greets text messages and echoes binary ones.

    Text frames are answered with ``Hello, <text>``, binary frames are sent
    back unchanged, and a CLOSE message ends the loop. Other message types
    are ignored.
    """
    websocket = web.WebSocketResponse()
    await websocket.prepare(request)

    kinds = web.WSMsgType
    async for msg in websocket:
        if msg.type is kinds.CLOSE:
            break
        if msg.type is kinds.TEXT:
            await websocket.send_str(f"Hello, {msg.data}")
            continue
        if msg.type is kinds.BINARY:
            await websocket.send_bytes(msg.data)

    return websocket
class TokenRefreshMiddleware:
    """Middleware that handles JWT token refresh automatically.

    It adds a bearer token to every request, obtains a new token when none
    is cached or the cached one has passed its local expiry time, and
    retries once with a fresh token when the server answers 401.
    """

    def __init__(self, token_endpoint: str, refresh_token: str) -> None:
        """Store the refresh endpoint URL and the refresh token to send to it."""
        self.token_endpoint = token_endpoint
        self.refresh_token = refresh_token
        self.access_token: str | None = None
        self.token_expires_at: float | None = None
        # Serializes refreshes so concurrent requests trigger only one.
        self._refresh_lock = asyncio.Lock()

    async def _refresh_access_token(self, session: ClientSession) -> str:
        """Return a valid access token, fetching a new one if necessary.

        If a cached token exists and has not reached its local expiry time,
        it is returned without contacting the server.

        Raises:
            ValueError: If the refresh response has no ``access_token``.
        """
        async with self._refresh_lock:
            # Check if another coroutine already refreshed the token
            if (
                self.token_expires_at
                and time.time() < self.token_expires_at
                and self.access_token
            ):
                _LOGGER.debug("Token already refreshed by another request")
                return self.access_token

            _LOGGER.info("Refreshing access token...")

            # Make refresh request without middleware to avoid recursion
            async with session.post(
                self.token_endpoint,
                json={"refresh_token": self.refresh_token},
                middlewares=(),  # Disable middleware for this request
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()

                if "access_token" not in data:
                    raise ValueError("No access_token in refresh response")

                self.access_token = data["access_token"]
                # Token expires in 5 minutes for demo, refresh 30 seconds early
                expires_in = data.get("expires_in", 300)
                self.token_expires_at = time.time() + expires_in - 30

                _LOGGER.info(
                    "Token refreshed successfully, expires in %s seconds", expires_in
                )

            if TYPE_CHECKING:
                assert self.access_token is not None  # Just assigned above
            return self.access_token

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Add auth token to request, refreshing if needed.

        When the server rejects the token with 401, the rejected token is
        discarded, a new one is requested and the request is sent once more.
        """
        # Skip token for refresh endpoint to avoid recursion
        if str(request.url).endswith("/token/refresh"):
            return await handler(request)

        # Refresh token if needed
        if not self.access_token or (
            self.token_expires_at and time.time() >= self.token_expires_at
        ):
            await self._refresh_access_token(request.session)

        # Add token to request; remember which one was sent so a 401 can be
        # attributed to exactly this token.
        token_used = self.access_token
        request.headers[hdrs.AUTHORIZATION] = f"Bearer {token_used}"
        _LOGGER.debug("Added Bearer token to request")

        # Execute request
        response = await handler(request)

        # If we get 401, try refreshing token once
        if response.status == HTTPStatus.UNAUTHORIZED:
            _LOGGER.info("Got 401, attempting token refresh...")
            # The server rejected a token that still looks valid locally.
            # Drop it, otherwise the refresh helper would hand the very same
            # token back. If another request already replaced it, keep the
            # newer token and skip the extra refresh.
            if self.access_token == token_used:
                self.access_token = None
            await self._refresh_access_token(request.session)
            request.headers[hdrs.AUTHORIZATION] = f"Bearer {self.access_token}"
            response = await handler(request)

        return response
"user_id": "user123", "username": "testuser", "issued_at": time.time(), } } def generate_access_token(self) -> str: """Generate a secure random access token.""" return secrets.token_urlsafe(32) async def _process_token_refresh(self, data: dict[str, str]) -> web.Response: """Process the token refresh request.""" refresh_token = data.get("refresh_token") if not refresh_token: return web.json_response({"error": "refresh_token required"}, status=400) # Hash the refresh token to look it up refresh_token_hash = hashlib.sha256(refresh_token.encode()).hexdigest() if refresh_token_hash not in self.refresh_tokens_db: return web.json_response({"error": "Invalid refresh token"}, status=401) user_data = self.refresh_tokens_db[refresh_token_hash] # Generate new access token access_token = self.generate_access_token() expires_in = 300 # 5 minutes for demo # Store the access token with expiry token_hash = hashlib.sha256(access_token.encode()).hexdigest() self.tokens_db[token_hash] = { "user_id": user_data["user_id"], "username": user_data["username"], "expires_at": time.time() + expires_in, "issued_at": time.time(), } # Clean up expired tokens periodically current_time = time.time() self.tokens_db = { k: v for k, v in self.tokens_db.items() if isinstance(v["expires_at"], float) and v["expires_at"] > current_time } return web.json_response( { "access_token": access_token, "token_type": "Bearer", "expires_in": expires_in, } ) async def handle_token_refresh(self, request: web.Request) -> web.Response: """Handle token refresh requests.""" try: data = await request.json() return await self._process_token_refresh(data) except json.JSONDecodeError: return web.json_response({"error": "Invalid request"}, status=400) async def verify_bearer_token( self, request: web.Request ) -> dict[str, str | float] | None: """Verify bearer token and return user data if valid.""" auth_header = request.headers.get(hdrs.AUTHORIZATION, "") if not auth_header.startswith("Bearer "): return None token = 
async def run_test_server() -> web.AppRunner:
    """Start the token-auth demo server on localhost:8080.

    Exposes the refresh endpoint (POST) and the two protected API endpoints
    (GET), then returns the started runner.
    """
    backend = TestServer()
    application = web.Application()
    router = application.router

    # (registration method, path, handler) for every endpoint.
    route_table = (
        (router.add_post, "/token/refresh", backend.handle_token_refresh),
        (router.add_get, "/api/protected", backend.handle_protected_resource),
        (router.add_get, "/api/user", backend.handle_user_info),
    )
    for register, path, handler in route_table:
        register(path, handler)

    app_runner = web.AppRunner(application)
    await app_runner.setup()
    await web.TCPSite(app_runner, "localhost", 8080).start()
    return app_runner
ClientSession(middlewares=(token_middleware,)) as session: print("=== Test 1: First request (will trigger token refresh) ===") async with session.get("http://localhost:8080/api/protected") as resp: if resp.status == 200: data = await resp.json() print(f"Success! Response: {data}") else: print(f"Failed with status: {resp.status}") print("\n=== Test 2: Second request (uses cached token) ===") async with session.get("http://localhost:8080/api/user") as resp: if resp.status == 200: data = await resp.json() print(f"User info: {data}") else: print(f"Failed with status: {resp.status}") print("\n=== Test 3: Multiple concurrent requests ===") print("(Should only refresh token once)") coros: list[Coroutine[Any, Any, ClientResponse]] = [] for i in range(3): coro = session.get("http://localhost:8080/api/protected") coros.append(coro) responses = await asyncio.gather(*coros) for i, resp in enumerate(responses): async with resp: if resp.status == 200: print(f"Request {i + 1}: Success") else: print(f"Request {i + 1}: Failed with {resp.status}") print("\n=== Test 4: Simulate token expiry ===") # For demo purposes, force token expiry token_middleware.token_expires_at = time.time() - 1 print("Token expired, next request should trigger refresh...") async with session.get("http://localhost:8080/api/protected") as resp: if resp.status == 200: data = await resp.json() print(f"Success after token refresh! 
Response: {data}") else: print(f"Failed with status: {resp.status}") print("\n=== Test 5: Request without middleware (no auth) ===") # Make a request without any middleware to show the difference async with session.get( "http://localhost:8080/api/protected", middlewares=(), # Bypass all middleware for this request ) as resp: print(f"Status: {resp.status}") if resp.status == 401: error = await resp.json() print(f"Failed as expected without auth: {error}") async def main() -> None: # Start test server server = await run_test_server() try: await run_tests() finally: await server.cleanup() if __name__ == "__main__": asyncio.run(main()) ================================================ FILE: examples/web_classview.py ================================================ #!/usr/bin/env python3 """Example for aiohttp.web class based views.""" import functools import json from aiohttp import web class MyView(web.View): async def get(self) -> web.StreamResponse: return web.json_response( { "method": self.request.method, "args": dict(self.request.rel_url.query), "headers": dict(self.request.headers), }, dumps=functools.partial(json.dumps, indent=4), ) async def post(self) -> web.StreamResponse: data = await self.request.post() return web.json_response( { "method": self.request.method, "data": dict(data), "headers": dict(self.request.headers), }, dumps=functools.partial(json.dumps, indent=4), ) async def index(request: web.Request) -> web.StreamResponse: txt = """ Class based view example

    Class based view example

    • / This page
    • /get Returns GET data.
    • /post Returns POST data.
    """ return web.Response(text=txt, content_type="text/html") def init() -> web.Application: app = web.Application() app.router.add_get("/", index) app.router.add_get("/get", MyView) app.router.add_post("/post", MyView) return app web.run_app(init()) ================================================ FILE: examples/web_cookies.py ================================================ #!/usr/bin/env python3 """Example for aiohttp.web basic server with cookies.""" from pprint import pformat from typing import NoReturn from aiohttp import web tmpl = """\ Login
    Logout
    {}
    """ async def root(request: web.Request) -> web.StreamResponse: resp = web.Response(content_type="text/html") resp.text = tmpl.format(pformat(request.cookies)) return resp async def login(request: web.Request) -> NoReturn: exc = web.HTTPFound(location="/") exc.set_cookie("AUTH", "secret") raise exc async def logout(request: web.Request) -> NoReturn: exc = web.HTTPFound(location="/") exc.del_cookie("AUTH") raise exc def init() -> web.Application: app = web.Application() app.router.add_get("/", root) app.router.add_get("/login", login) app.router.add_get("/logout", logout) return app web.run_app(init()) ================================================ FILE: examples/web_rewrite_headers_middleware.py ================================================ #!/usr/bin/env python3 """Example for rewriting response headers by middleware.""" from aiohttp import web from aiohttp.typedefs import Handler async def handler(request: web.Request) -> web.StreamResponse: return web.Response(text="Everything is fine") async def middleware(request: web.Request, handler: Handler) -> web.StreamResponse: try: response = await handler(request) except web.HTTPException as exc: raise exc if not response.prepared: response.headers["SERVER"] = "Secured Server Software" return response def init() -> web.Application: app = web.Application(middlewares=[middleware]) app.router.add_get("/", handler) return app web.run_app(init()) ================================================ FILE: examples/web_srv.py ================================================ #!/usr/bin/env python3 """Example for aiohttp.web basic server.""" import textwrap from aiohttp import web async def intro(request: web.Request) -> web.StreamResponse: txt = textwrap.dedent("""\ Type {url}/hello/John {url}/simple or {url}/change_body in browser url bar """).format(url="127.0.0.1:8080") binary = txt.encode("utf8") resp = web.StreamResponse() resp.content_length = len(binary) resp.content_type = "text/plain" await resp.prepare(request) 
await resp.write(binary) return resp async def simple(request: web.Request) -> web.StreamResponse: return web.Response(text="Simple answer") async def change_body(request: web.Request) -> web.StreamResponse: resp = web.Response() resp.body = b"Body changed" resp.content_type = "text/plain" return resp async def hello(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse() name = request.match_info.get("name", "Anonymous") answer = ("Hello, " + name).encode("utf8") resp.content_length = len(answer) resp.content_type = "text/plain" await resp.prepare(request) await resp.write(answer) await resp.write_eof() return resp def init() -> web.Application: app = web.Application() app.router.add_get("/", intro) app.router.add_get("/simple", simple) app.router.add_get("/change_body", change_body) app.router.add_get("/hello/{name}", hello) app.router.add_get("/hello", hello) return app web.run_app(init()) ================================================ FILE: examples/web_srv_route_deco.py ================================================ #!/usr/bin/env python3 """Example for aiohttp.web basic server with decorator definition for routes.""" import textwrap from aiohttp import web routes = web.RouteTableDef() @routes.get("/") async def intro(request: web.Request) -> web.StreamResponse: txt = textwrap.dedent("""\ Type {url}/hello/John {url}/simple or {url}/change_body in browser url bar """).format(url="127.0.0.1:8080") binary = txt.encode("utf8") resp = web.StreamResponse() resp.content_length = len(binary) resp.content_type = "text/plain" await resp.prepare(request) await resp.write(binary) return resp @routes.get("/simple") async def simple(request: web.Request) -> web.StreamResponse: return web.Response(text="Simple answer") @routes.get("/change_body") async def change_body(request: web.Request) -> web.StreamResponse: resp = web.Response() resp.body = b"Body changed" resp.content_type = "text/plain" return resp @routes.get("/hello") async def hello(request: 
web.Request) -> web.StreamResponse: resp = web.StreamResponse() name = request.match_info.get("name", "Anonymous") answer = ("Hello, " + name).encode("utf8") resp.content_length = len(answer) resp.content_type = "text/plain" await resp.prepare(request) await resp.write(answer) await resp.write_eof() return resp def init() -> web.Application: app = web.Application() app.router.add_routes(routes) return app web.run_app(init()) ================================================ FILE: examples/web_srv_route_table.py ================================================ #!/usr/bin/env python3 """Example for aiohttp.web basic server with table definition for routes.""" import textwrap from aiohttp import web async def intro(request: web.Request) -> web.StreamResponse: txt = textwrap.dedent("""\ Type {url}/hello/John {url}/simple or {url}/change_body in browser url bar """).format(url="127.0.0.1:8080") binary = txt.encode("utf8") resp = web.StreamResponse() resp.content_length = len(binary) resp.content_type = "text/plain" await resp.prepare(request) await resp.write(binary) return resp async def simple(request: web.Request) -> web.StreamResponse: return web.Response(text="Simple answer") async def change_body(request: web.Request) -> web.StreamResponse: resp = web.Response() resp.body = b"Body changed" resp.content_type = "text/plain" return resp async def hello(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse() name = request.match_info.get("name", "Anonymous") answer = ("Hello, " + name).encode("utf8") resp.content_length = len(answer) resp.content_type = "text/plain" await resp.prepare(request) await resp.write(answer) await resp.write_eof() return resp def init() -> web.Application: app = web.Application() app.router.add_routes( [ web.get("/", intro), web.get("/simple", simple), web.get("/change_body", change_body), web.get("/hello/{name}", hello), web.get("/hello", hello), ] ) return app web.run_app(init()) 
================================================ FILE: examples/web_ws.py ================================================ #!/usr/bin/env python3 """Example for aiohttp.web websocket server.""" # The extra strict mypy settings are here to help test that `Application[AppKey()]` # syntax is working correctly. A regression will cause mypy to raise an error. # mypy: disallow-any-expr, disallow-any-unimported, disallow-subclassing-any import os from aiohttp import web WS_FILE = os.path.join(os.path.dirname(__file__), "websocket.html") sockets = web.AppKey("sockets", list[web.WebSocketResponse]) async def wshandler(request: web.Request) -> web.WebSocketResponse | web.Response: resp = web.WebSocketResponse() available = resp.can_prepare(request) if not available: with open(WS_FILE, "rb") as fp: return web.Response(body=fp.read(), content_type="text/html") await resp.prepare(request) await resp.send_str("Welcome!!!") try: print("Someone joined.") for ws in request.app[sockets]: await ws.send_str("Someone joined") request.app[sockets].append(resp) async for msg in resp: if msg.type is web.WSMsgType.TEXT: for ws in request.app[sockets]: if ws is not resp: await ws.send_str(msg.data) else: return resp return resp finally: request.app[sockets].remove(resp) print("Someone disconnected.") for ws in request.app[sockets]: await ws.send_str("Someone disconnected.") async def on_shutdown(app: web.Application) -> None: for ws in app[sockets]: await ws.close() def init() -> web.Application: app = web.Application() l: list[web.WebSocketResponse] = [] app[sockets] = l app.router.add_get("/", wshandler) app.on_shutdown.append(on_shutdown) return app web.run_app(init()) ================================================ FILE: examples/websocket.html ================================================

    Chat!

     | Status: disconnected
    ================================================ FILE: pyproject.toml ================================================ [build-system] requires = [ "pkgconfig", # setuptools >= 67.0 required for Python 3.12+ support # Next step should be >= 77.0 for PEP 639 support # Don't bump too early to give distributors time to update # their setuptools version. "setuptools >= 67.0", ] build-backend = "setuptools.build_meta" [project] name = "aiohttp" # TODO: Update to just 'license = "..."' once setuptools is bumped to >=77 license = {text = "Apache-2.0 AND MIT"} description = "Async http client/server framework (asyncio)" readme = "README.rst" classifiers = [ "Development Status :: 5 - Production/Stable", "Framework :: AsyncIO", "Intended Audience :: Developers", "Operating System :: POSIX", "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", "Topic :: Internet :: WWW/HTTP", ] requires-python = ">= 3.10" dependencies = [ "aiohappyeyeballs >= 2.5.0", "aiosignal >= 1.4.0", "async-timeout >= 4.0, < 6.0 ; python_version < '3.11'", "frozenlist >= 1.1.1", "multidict >=4.5, < 7.0", "propcache >= 0.2.0", "typing_extensions >= 4.4 ; python_version < '3.13'", "yarl >= 1.17.0, < 2.0", ] dynamic = [ "version", ] [project.optional-dependencies] speedups = [ "aiodns >= 3.3.0", "Brotli >= 1.2; platform_python_implementation == 'CPython'", "brotlicffi >= 1.2; platform_python_implementation != 'CPython'", "backports.zstd; platform_python_implementation == 'CPython' and python_version < '3.14'", ] [[project.maintainers]] name = "aiohttp team" email = "team@aiohttp.org" [project.urls] "Homepage" = "https://github.com/aio-libs/aiohttp" "Chat: Matrix" = 
"https://matrix.to/#/#aio-libs:matrix.org" "Chat: Matrix Space" = "https://matrix.to/#/#aio-libs-space:matrix.org" "CI: GitHub Actions" = "https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI" "Coverage: codecov" = "https://codecov.io/github/aio-libs/aiohttp" "Docs: Changelog" = "https://docs.aiohttp.org/en/stable/changes.html" "Docs: RTD" = "https://docs.aiohttp.org" "GitHub: issues" = "https://github.com/aio-libs/aiohttp/issues" "GitHub: repo" = "https://github.com/aio-libs/aiohttp" [tool.setuptools] license-files = [ # TODO: Use 'project.license-files' instead once setuptools is bumped to >=77 "LICENSE.txt", "vendor/llhttp/LICENSE", ] [tool.setuptools.dynamic] version = {attr = "aiohttp.__version__"} [tool.setuptools.packages.find] include = [ "aiohttp", "aiohttp.*", ] [tool.setuptools.exclude-package-data] "*" = ["*.c", "*.h"] [tool.towncrier] package = "aiohttp" filename = "CHANGES.rst" directory = "CHANGES/" title_format = "{version} ({project_date})" template = "CHANGES/.TEMPLATE.rst" issue_format = "{issue}" # NOTE: The types are declared because: # NOTE: - there is no mechanism to override just the value of # NOTE: `tool.towncrier.type.misc.showcontent`; # NOTE: - and, we want to declare extra non-default types for # NOTE: clarity and flexibility. [[tool.towncrier.section]] path = "" [[tool.towncrier.type]] # Something we deemed an improper undesired behavior that got corrected # in the release to match pre-agreed expectations. directory = "bugfix" name = "Bug fixes" showcontent = true [[tool.towncrier.type]] # New behaviors, public APIs. That sort of stuff. directory = "feature" name = "Features" showcontent = true [[tool.towncrier.type]] # Declarations of future API removals and breaking changes in behavior. directory = "deprecation" name = "Deprecations (removal in next major release)" showcontent = true [[tool.towncrier.type]] # When something public gets removed in a breaking way. Could be # deprecated in an earlier release. 
directory = "breaking" name = "Removals and backward incompatible breaking changes" showcontent = true [[tool.towncrier.type]] # Notable updates to the documentation structure or build process. directory = "doc" name = "Improved documentation" showcontent = true [[tool.towncrier.type]] # Notes for downstreams about unobvious side effects and tooling. Changes # in the test invocation considerations and runtime assumptions. directory = "packaging" name = "Packaging updates and notes for downstreams" showcontent = true [[tool.towncrier.type]] # Stuff that affects the contributor experience. e.g. Running tests, # building the docs, setting up the development environment. directory = "contrib" name = "Contributor-facing changes" showcontent = true [[tool.towncrier.type]] # Changes that are hard to assign to any of the above categories. directory = "misc" name = "Miscellaneous internal changes" showcontent = true [tool.cibuildwheel] test-command = "" # don't build PyPy wheels, install from source instead skip = "pp*" [tool.codespell] skip = '.git,*.pdf,*.svg,Makefile,CONTRIBUTORS.txt,venvs,_build' ignore-words-list = 'te,assertIn' [tool.slotscheck] # TODO(3.13): Remove aiohttp.helpers once https://github.com/python/cpython/pull/106771 # is available in all supported cpython versions exclude-modules = "(^aiohttp\\.helpers)" ================================================ FILE: requirements/base-ft.in ================================================ -r runtime-deps.in gunicorn ================================================ FILE: requirements/base-ft.txt ================================================ # # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/base-ft.txt --strip-extras requirements/base-ft.in # aiodns==4.0.0 # via -r requirements/runtime-deps.in aiohappyeyeballs==2.6.1 # via -r requirements/runtime-deps.in aiosignal==1.4.0 # via -r requirements/runtime-deps.in 
async-timeout==5.0.1 ; python_version < "3.11" # via -r requirements/runtime-deps.in backports-zstd==1.3.0 ; platform_python_implementation == "CPython" and python_version < "3.14" # via -r requirements/runtime-deps.in brotli==1.2.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in cffi==2.0.0 # via pycares frozenlist==1.8.0 # via # -r requirements/runtime-deps.in # aiosignal gunicorn==25.1.0 # via -r requirements/base-ft.in idna==3.11 # via yarl multidict==6.7.1 # via # -r requirements/runtime-deps.in # yarl packaging==26.0 # via gunicorn propcache==0.4.1 # via # -r requirements/runtime-deps.in # yarl pycares==5.0.1 # via aiodns pycparser==3.0 # via cffi typing-extensions==4.15.0 ; python_version < "3.13" # via # -r requirements/runtime-deps.in # aiosignal # multidict yarl==1.22.0 # via -r requirements/runtime-deps.in ================================================ FILE: requirements/base.in ================================================ -r runtime-deps.in gunicorn uvloop; platform_system != "Windows" and implementation_name == "cpython" # MagicStack/uvloop#14 winloop; platform_system == "Windows" and implementation_name == "cpython" ================================================ FILE: requirements/base.txt ================================================ # # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/base.txt --strip-extras requirements/base.in # aiodns==4.0.0 # via -r requirements/runtime-deps.in aiohappyeyeballs==2.6.1 # via -r requirements/runtime-deps.in aiosignal==1.4.0 # via -r requirements/runtime-deps.in async-timeout==5.0.1 ; python_version < "3.11" # via -r requirements/runtime-deps.in backports-zstd==1.3.0 ; platform_python_implementation == "CPython" and python_version < "3.14" # via -r requirements/runtime-deps.in brotli==1.2.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in 
cffi==2.0.0 # via pycares frozenlist==1.8.0 # via # -r requirements/runtime-deps.in # aiosignal gunicorn==25.1.0 # via -r requirements/base.in idna==3.11 # via yarl multidict==6.7.1 # via # -r requirements/runtime-deps.in # yarl packaging==26.0 # via gunicorn propcache==0.4.1 # via # -r requirements/runtime-deps.in # yarl pycares==5.0.1 # via aiodns pycparser==3.0 # via cffi typing-extensions==4.15.0 ; python_version < "3.13" # via # -r requirements/runtime-deps.in # aiosignal # multidict uvloop==0.21.0 ; platform_system != "Windows" and implementation_name == "cpython" # via -r requirements/base.in yarl==1.22.0 # via -r requirements/runtime-deps.in ================================================ FILE: requirements/constraints.in ================================================ -r cython.in -r dev.in -r doc-spelling.in -r lint.in ================================================ FILE: requirements/constraints.txt ================================================ # # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/constraints.txt --strip-extras requirements/constraints.in # aiodns==4.0.0 # via # -r requirements/lint.in # -r requirements/runtime-deps.in aiohappyeyeballs==2.6.1 # via -r requirements/runtime-deps.in aiohttp-theme==0.1.7 # via -r requirements/doc.in aiosignal==1.4.0 # via -r requirements/runtime-deps.in alabaster==1.0.0 # via sphinx annotated-types==0.7.0 # via pydantic async-timeout==5.0.1 ; python_version < "3.11" # via # -r requirements/runtime-deps.in # valkey babel==2.18.0 # via sphinx backports-zstd==1.3.0 ; implementation_name == "cpython" # via # -r requirements/lint.in # -r requirements/runtime-deps.in blockbuster==1.5.26 # via # -r requirements/lint.in # -r requirements/test-common.in brotli==1.2.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in build==1.4.0 # via pip-tools certifi==2026.2.25 # via requests 
cffi==2.0.0 # via # cryptography # pycares # pytest-codspeed cfgv==3.5.0 # via pre-commit charset-normalizer==3.4.6 # via requests click==8.3.1 # via # pip-tools # slotscheck # towncrier # wait-for-it coverage==7.13.5 # via # -r requirements/test-common.in # pytest-cov cryptography==46.0.5 # via trustme cython==3.2.4 # via -r requirements/cython.in distlib==0.4.0 # via virtualenv docutils==0.21.2 # via sphinx exceptiongroup==1.3.1 # via pytest execnet==2.1.2 # via pytest-xdist filelock==3.25.2 # via # python-discovery # virtualenv forbiddenfruit==0.1.4 # via blockbuster freezegun==1.5.5 # via # -r requirements/lint.in # -r requirements/test-common.in frozenlist==1.8.0 # via # -r requirements/runtime-deps.in # aiosignal gunicorn==25.1.0 # via -r requirements/base.in identify==2.6.17 # via pre-commit idna==3.11 # via # requests # trustme # yarl imagesize==2.0.0 # via sphinx iniconfig==2.3.0 # via pytest isal==1.7.2 ; python_version < "3.14" # via # -r requirements/lint.in # -r requirements/test-common.in jinja2==3.1.6 # via # sphinx # towncrier librt==0.8.0 # via mypy markdown-it-py==4.0.0 # via rich markupsafe==3.0.3 # via jinja2 mdurl==0.1.2 # via markdown-it-py multidict==6.7.1 # via # -r requirements/multidict.in # -r requirements/runtime-deps.in # yarl mypy==1.19.1 ; implementation_name == "cpython" # via # -r requirements/lint.in # -r requirements/test-common.in mypy-extensions==1.1.0 # via mypy nodeenv==1.10.0 # via pre-commit packaging==26.0 # via # build # gunicorn # pytest # sphinx # wheel pathspec==1.0.4 # via mypy pip-tools==7.5.3 # via -r requirements/dev.in pkgconfig==1.6.0 # via -r requirements/test-common.in platformdirs==4.9.4 # via # python-discovery # virtualenv pluggy==1.6.0 # via # pytest # pytest-cov pre-commit==4.5.1 # via -r requirements/lint.in propcache==0.4.1 # via # -r requirements/runtime-deps.in # yarl proxy-py==2.4.10 # via # -r requirements/lint.in # -r requirements/test-common.in pycares==5.0.1 # via aiodns pycparser==3.0 # via cffi 
pydantic==2.12.5 # via python-on-whales pydantic-core==2.41.5 # via pydantic pyenchant==3.3.0 # via sphinxcontrib-spelling pygments==2.19.2 # via # pytest # rich # sphinx pyproject-hooks==1.2.0 # via # build # pip-tools pytest==9.0.2 # via # -r requirements/lint.in # -r requirements/test-common.in # pytest-codspeed # pytest-cov # pytest-mock # pytest-xdist pytest-codspeed==4.3.0 # via # -r requirements/lint.in # -r requirements/test-common.in pytest-cov==7.0.0 # via -r requirements/test-common.in pytest-mock==3.15.1 # via # -r requirements/lint.in # -r requirements/test-common.in pytest-xdist==3.8.0 # via -r requirements/test-common.in python-dateutil==2.9.0.post0 # via freezegun python-discovery==1.2.0 # via virtualenv python-on-whales==0.81.0 # via # -r requirements/lint.in # -r requirements/test-common.in pyyaml==6.0.3 # via pre-commit requests==2.32.5 # via # sphinx # sphinxcontrib-spelling rich==14.3.3 # via pytest-codspeed setuptools-git==1.2 # via -r requirements/test-common.in six==1.17.0 # via python-dateutil slotscheck==0.19.1 # via -r requirements/lint.in snowballstemmer==3.0.1 # via sphinx sphinx==8.1.3 # via # -r requirements/doc.in # sphinxcontrib-spelling # sphinxcontrib-towncrier sphinxcontrib-applehelp==2.0.0 # via sphinx sphinxcontrib-devhelp==2.0.0 # via sphinx sphinxcontrib-htmlhelp==2.1.0 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx sphinxcontrib-spelling==8.0.2 ; platform_system != "Windows" # via -r requirements/doc-spelling.in sphinxcontrib-towncrier==0.5.0a0 # via -r requirements/doc.in tomli==2.4.0 # via # build # coverage # mypy # pip-tools # pytest # slotscheck # sphinx # towncrier towncrier==25.8.0 # via # -r requirements/doc.in # sphinxcontrib-towncrier trustme==1.2.1 ; platform_machine != "i686" # via # -r requirements/lint.in # -r requirements/test-common.in typing-extensions==4.15.0 ; python_version < "3.13" # via # -r 
requirements/runtime-deps.in # aiosignal # cryptography # exceptiongroup # multidict # mypy # pydantic # pydantic-core # python-on-whales # typing-inspection # virtualenv typing-inspection==0.4.2 # via pydantic urllib3==2.6.3 # via requests uvloop==0.21.0 ; platform_system != "Windows" # via # -r requirements/base.in # -r requirements/lint.in valkey==6.1.1 # via -r requirements/lint.in virtualenv==21.2.0 # via pre-commit wait-for-it==2.3.0 # via -r requirements/test-common.in wheel==0.46.3 # via pip-tools yarl==1.22.0 # via -r requirements/runtime-deps.in zlib-ng==1.0.0 # via # -r requirements/lint.in # -r requirements/test-common.in # The following packages are considered to be unsafe in a requirements file: pip==26.0.1 # via pip-tools setuptools==82.0.1 # via pip-tools ================================================ FILE: requirements/cython.in ================================================ -r multidict.in Cython ================================================ FILE: requirements/cython.txt ================================================ # # This file is autogenerated by pip-compile with python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/cython.txt --resolver=backtracking --strip-extras requirements/cython.in # cython==3.2.4 # via -r requirements/cython.in multidict==6.7.1 # via -r requirements/multidict.in typing-extensions==4.15.0 # via multidict ================================================ FILE: requirements/dev.in ================================================ -r lint.in -r test.in -r doc.in pip-tools ================================================ FILE: requirements/dev.txt ================================================ # # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/dev.txt --strip-extras requirements/dev.in # aiodns==4.0.0 # via # -r requirements/lint.in # -r requirements/runtime-deps.in 
aiohappyeyeballs==2.6.1 # via -r requirements/runtime-deps.in aiohttp-theme==0.1.7 # via -r requirements/doc.in aiosignal==1.4.0 # via -r requirements/runtime-deps.in alabaster==1.0.0 # via sphinx annotated-types==0.7.0 # via pydantic async-timeout==5.0.1 ; python_version < "3.11" # via # -r requirements/runtime-deps.in # valkey babel==2.18.0 # via sphinx backports-zstd==1.3.0 ; platform_python_implementation == "CPython" and python_version < "3.14" # via # -r requirements/lint.in # -r requirements/runtime-deps.in blockbuster==1.5.26 # via # -r requirements/lint.in # -r requirements/test-common.in brotli==1.2.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in build==1.4.0 # via pip-tools certifi==2026.2.25 # via requests cffi==2.0.0 # via # cryptography # pycares # pytest-codspeed cfgv==3.5.0 # via pre-commit charset-normalizer==3.4.6 # via requests click==8.3.1 # via # pip-tools # slotscheck # towncrier # wait-for-it coverage==7.13.5 # via # -r requirements/test-common.in # pytest-cov cryptography==46.0.5 # via trustme distlib==0.4.0 # via virtualenv docutils==0.21.2 # via sphinx exceptiongroup==1.3.1 # via pytest execnet==2.1.2 # via pytest-xdist filelock==3.25.2 # via # python-discovery # virtualenv forbiddenfruit==0.1.4 # via blockbuster freezegun==1.5.5 # via # -r requirements/lint.in # -r requirements/test-common.in frozenlist==1.8.0 # via # -r requirements/runtime-deps.in # aiosignal gunicorn==25.1.0 # via -r requirements/base.in identify==2.6.17 # via pre-commit idna==3.11 # via # requests # trustme # yarl imagesize==2.0.0 # via sphinx iniconfig==2.3.0 # via pytest isal==1.7.2 ; python_version < "3.14" # via # -r requirements/lint.in # -r requirements/test-common.in jinja2==3.1.6 # via # sphinx # towncrier librt==0.8.0 # via mypy markdown-it-py==4.0.0 # via rich markupsafe==3.0.3 # via jinja2 mdurl==0.1.2 # via markdown-it-py multidict==6.7.1 # via # -r requirements/runtime-deps.in # yarl mypy==1.19.1 ; implementation_name 
== "cpython" # via # -r requirements/lint.in # -r requirements/test-common.in mypy-extensions==1.1.0 # via mypy nodeenv==1.10.0 # via pre-commit packaging==26.0 # via # build # gunicorn # pytest # sphinx # wheel pathspec==1.0.4 # via mypy pip-tools==7.5.3 # via -r requirements/dev.in pkgconfig==1.6.0 # via -r requirements/test-common.in platformdirs==4.9.4 # via # python-discovery # virtualenv pluggy==1.6.0 # via # pytest # pytest-cov pre-commit==4.5.1 # via -r requirements/lint.in propcache==0.4.1 # via # -r requirements/runtime-deps.in # yarl proxy-py==2.4.10 # via # -r requirements/lint.in # -r requirements/test-common.in pycares==5.0.1 # via aiodns pycparser==3.0 # via cffi pydantic==2.12.5 # via python-on-whales pydantic-core==2.41.5 # via pydantic pygments==2.19.2 # via # pytest # rich # sphinx pyproject-hooks==1.2.0 # via # build # pip-tools pytest==9.0.2 # via # -r requirements/lint.in # -r requirements/test-common.in # pytest-codspeed # pytest-cov # pytest-mock # pytest-xdist pytest-codspeed==4.3.0 # via # -r requirements/lint.in # -r requirements/test-common.in pytest-cov==7.0.0 # via -r requirements/test-common.in pytest-mock==3.15.1 # via # -r requirements/lint.in # -r requirements/test-common.in pytest-xdist==3.8.0 # via -r requirements/test-common.in python-dateutil==2.9.0.post0 # via freezegun python-discovery==1.2.0 # via virtualenv python-on-whales==0.81.0 # via # -r requirements/lint.in # -r requirements/test-common.in pyyaml==6.0.3 # via pre-commit requests==2.32.5 # via sphinx rich==14.3.3 # via pytest-codspeed setuptools-git==1.2 # via -r requirements/test-common.in six==1.17.0 # via python-dateutil slotscheck==0.19.1 # via -r requirements/lint.in snowballstemmer==3.0.1 # via sphinx sphinx==8.1.3 # via # -r requirements/doc.in # sphinxcontrib-towncrier sphinxcontrib-applehelp==2.0.0 # via sphinx sphinxcontrib-devhelp==2.0.0 # via sphinx sphinxcontrib-htmlhelp==2.1.0 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx 
sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx sphinxcontrib-towncrier==0.5.0a0 # via -r requirements/doc.in tomli==2.4.0 # via # build # coverage # mypy # pip-tools # pytest # slotscheck # sphinx # towncrier towncrier==25.8.0 # via # -r requirements/doc.in # sphinxcontrib-towncrier trustme==1.2.1 ; platform_machine != "i686" # via # -r requirements/lint.in # -r requirements/test-common.in typing-extensions==4.15.0 ; python_version < "3.13" # via # -r requirements/runtime-deps.in # aiosignal # cryptography # exceptiongroup # multidict # mypy # pydantic # pydantic-core # python-on-whales # typing-inspection # virtualenv typing-inspection==0.4.2 # via pydantic urllib3==2.6.3 # via requests uvloop==0.21.0 ; platform_system != "Windows" and implementation_name == "cpython" # via # -r requirements/base.in # -r requirements/lint.in valkey==6.1.1 # via -r requirements/lint.in virtualenv==21.2.0 # via pre-commit wait-for-it==2.3.0 # via -r requirements/test-common.in wheel==0.46.3 # via pip-tools yarl==1.22.0 # via -r requirements/runtime-deps.in zlib-ng==1.0.0 # via # -r requirements/lint.in # -r requirements/test-common.in # The following packages are considered to be unsafe in a requirements file: pip==26.0.1 # via pip-tools setuptools==82.0.1 # via pip-tools ================================================ FILE: requirements/doc-spelling.in ================================================ -r doc.in sphinxcontrib-spelling; platform_system!="Windows" # We only use it in GitHub Actions CI/CD ================================================ FILE: requirements/doc-spelling.txt ================================================ # # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/doc-spelling.txt --strip-extras requirements/doc-spelling.in # aiohttp-theme==0.1.7 # via -r requirements/doc.in alabaster==1.0.0 # via sphinx babel==2.18.0 # via 
sphinx certifi==2026.2.25 # via requests charset-normalizer==3.4.6 # via requests click==8.3.1 # via towncrier docutils==0.21.2 # via sphinx idna==3.11 # via requests imagesize==2.0.0 # via sphinx jinja2==3.1.6 # via # sphinx # towncrier markupsafe==3.0.3 # via jinja2 packaging==26.0 # via sphinx pyenchant==3.3.0 # via sphinxcontrib-spelling pygments==2.19.2 # via sphinx requests==2.32.5 # via # sphinx # sphinxcontrib-spelling snowballstemmer==3.0.1 # via sphinx sphinx==8.1.3 # via # -r requirements/doc.in # sphinxcontrib-spelling # sphinxcontrib-towncrier sphinxcontrib-applehelp==2.0.0 # via sphinx sphinxcontrib-devhelp==2.0.0 # via sphinx sphinxcontrib-htmlhelp==2.1.0 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx sphinxcontrib-spelling==8.0.2 ; platform_system != "Windows" # via -r requirements/doc-spelling.in sphinxcontrib-towncrier==0.5.0a0 # via -r requirements/doc.in tomli==2.4.0 # via # sphinx # towncrier towncrier==25.8.0 # via # -r requirements/doc.in # sphinxcontrib-towncrier urllib3==2.6.3 # via requests ================================================ FILE: requirements/doc.in ================================================ aiohttp-theme sphinx sphinxcontrib-towncrier towncrier ================================================ FILE: requirements/doc.txt ================================================ # # This file is autogenerated by pip-compile with python 3.10 # To update, run: # # pip-compile --allow-unsafe --output-file=requirements/doc.txt --resolver=backtracking --strip-extras requirements/doc.in # aiohttp-theme==0.1.7 # via -r requirements/doc.in alabaster==1.0.0 # via sphinx babel==2.18.0 # via sphinx certifi==2026.2.25 # via requests charset-normalizer==3.4.6 # via requests click==8.3.1 # via towncrier docutils==0.21.2 # via sphinx idna==3.11 # via requests imagesize==2.0.0 # via sphinx jinja2==3.1.6 # via # sphinx # towncrier markupsafe==3.0.3 # 
via jinja2 packaging==26.0 # via sphinx pygments==2.19.2 # via sphinx requests==2.32.5 # via sphinx snowballstemmer==3.0.1 # via sphinx sphinx==8.1.3 # via # -r requirements/doc.in # sphinxcontrib-towncrier sphinxcontrib-applehelp==2.0.0 # via sphinx sphinxcontrib-devhelp==2.0.0 # via sphinx sphinxcontrib-htmlhelp==2.1.0 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx sphinxcontrib-towncrier==0.5.0a0 # via -r requirements/doc.in tomli==2.4.0 # via # sphinx # towncrier towncrier==25.8.0 # via # -r requirements/doc.in # sphinxcontrib-towncrier urllib3==2.6.3 # via requests ================================================ FILE: requirements/lint.in ================================================ aiodns backports.zstd; implementation_name == "cpython" blockbuster freezegun isal mypy; implementation_name == "cpython" pre-commit proxy.py pytest pytest-mock pytest_codspeed python-on-whales slotscheck trustme uvloop; platform_system != "Windows" valkey zlib_ng ================================================ FILE: requirements/lint.txt ================================================ # # This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/lint.txt --strip-extras requirements/lint.in # aiodns==4.0.0 # via -r requirements/lint.in annotated-types==0.7.0 # via pydantic async-timeout==5.0.1 # via valkey backports-zstd==1.3.0 ; implementation_name == "cpython" # via -r requirements/lint.in blockbuster==1.5.26 # via -r requirements/lint.in cffi==2.0.0 # via # cryptography # pycares # pytest-codspeed cfgv==3.5.0 # via pre-commit click==8.3.1 # via slotscheck cryptography==46.0.5 # via trustme distlib==0.4.0 # via virtualenv exceptiongroup==1.3.1 # via pytest filelock==3.25.2 # via # python-discovery # virtualenv forbiddenfruit==0.1.4 # via blockbuster freezegun==1.5.5 # via -r 
requirements/lint.in identify==2.6.17 # via pre-commit idna==3.11 # via trustme iniconfig==2.3.0 # via pytest isal==1.7.2 # via -r requirements/lint.in librt==0.8.0 # via mypy markdown-it-py==4.0.0 # via rich mdurl==0.1.2 # via markdown-it-py mypy==1.19.1 ; implementation_name == "cpython" # via -r requirements/lint.in mypy-extensions==1.1.0 # via mypy nodeenv==1.10.0 # via pre-commit packaging==26.0 # via pytest pathspec==1.0.4 # via mypy platformdirs==4.9.4 # via # python-discovery # virtualenv pluggy==1.6.0 # via pytest pre-commit==4.5.1 # via -r requirements/lint.in proxy-py==2.4.10 # via -r requirements/lint.in pycares==5.0.1 # via aiodns pycparser==3.0 # via cffi pydantic==2.12.5 # via python-on-whales pydantic-core==2.41.5 # via pydantic pygments==2.19.2 # via # pytest # rich pytest==9.0.2 # via # -r requirements/lint.in # pytest-codspeed # pytest-mock pytest-codspeed==4.3.0 # via -r requirements/lint.in pytest-mock==3.15.1 # via -r requirements/lint.in python-dateutil==2.9.0.post0 # via freezegun python-discovery==1.2.0 # via virtualenv python-on-whales==0.81.0 # via -r requirements/lint.in pyyaml==6.0.3 # via pre-commit rich==14.3.3 # via pytest-codspeed six==1.17.0 # via python-dateutil slotscheck==0.19.1 # via -r requirements/lint.in tomli==2.4.0 # via # mypy # pytest # slotscheck trustme==1.2.1 # via -r requirements/lint.in typing-extensions==4.15.0 # via # cryptography # exceptiongroup # mypy # pydantic # pydantic-core # python-on-whales # typing-inspection # virtualenv typing-inspection==0.4.2 # via pydantic uvloop==0.21.0 ; platform_system != "Windows" # via -r requirements/lint.in valkey==6.1.1 # via -r requirements/lint.in virtualenv==21.2.0 # via pre-commit zlib-ng==1.0.0 # via -r requirements/lint.in ================================================ FILE: requirements/multidict.in ================================================ multidict ================================================ FILE: requirements/multidict.txt 
================================================ # # This file is autogenerated by pip-compile with python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/multidict.txt --resolver=backtracking --strip-extras requirements/multidict.in # multidict==6.7.1 # via -r requirements/multidict.in typing-extensions==4.15.0 # via multidict ================================================ FILE: requirements/runtime-deps.in ================================================ # Extracted from `pyproject.toml` via `make sync-direct-runtime-deps` aiodns >= 3.3.0 aiohappyeyeballs >= 2.5.0 aiosignal >= 1.4.0 async-timeout >= 4.0, < 6.0 ; python_version < '3.11' backports.zstd; platform_python_implementation == 'CPython' and python_version < '3.14' Brotli >= 1.2; platform_python_implementation == 'CPython' brotlicffi >= 1.2; platform_python_implementation != 'CPython' frozenlist >= 1.1.1 multidict >=4.5, < 7.0 propcache >= 0.2.0 typing_extensions >= 4.4 ; python_version < '3.13' yarl >= 1.17.0, < 2.0 ================================================ FILE: requirements/runtime-deps.txt ================================================ # # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/runtime-deps.txt --strip-extras requirements/runtime-deps.in # aiodns==4.0.0 # via -r requirements/runtime-deps.in aiohappyeyeballs==2.6.1 # via -r requirements/runtime-deps.in aiosignal==1.4.0 # via -r requirements/runtime-deps.in async-timeout==5.0.1 ; python_version < "3.11" # via -r requirements/runtime-deps.in backports-zstd==1.3.0 ; platform_python_implementation == "CPython" and python_version < "3.14" # via -r requirements/runtime-deps.in brotli==1.2.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in cffi==2.0.0 # via pycares frozenlist==1.8.0 # via # -r requirements/runtime-deps.in # aiosignal idna==3.11 # via yarl 
#!/usr/bin/env python
"""Sync direct runtime dependencies from pyproject.toml to runtime-deps.in.

The script is expected to run from the repository root: both
``pyproject.toml`` and ``requirements/runtime-deps.in`` are resolved
relative to the current working directory.
"""

import sys
from pathlib import Path

if sys.version_info >= (3, 11):
    import tomllib
else:
    raise RuntimeError("Use Python 3.11+ to run 'make sync-direct-runtime-deps'")

# TOML documents are UTF-8 by specification, so decode explicitly instead of
# relying on the locale's preferred encoding (which is not UTF-8 everywhere).
data = tomllib.loads(Path("pyproject.toml").read_text(encoding="utf-8"))
project = data["project"]

# Direct dependencies plus the "speedups" extra, ordered case-insensitively so
# the generated file is stable regardless of how names are capitalised.
reqs = sorted(
    project["dependencies"] + project["optional-dependencies"]["speedups"],
    key=str.casefold,
)

header = "# Extracted from `pyproject.toml` via `make sync-direct-runtime-deps`\n\n"
# Write with the same explicit encoding used for reading so any non-ASCII
# character in a requirement line round-trips unchanged.
with open(Path("requirements", "runtime-deps.in"), "w", encoding="utf-8") as outfile:
    outfile.write(header)
    outfile.write("\n".join(reqs) + "\n")
blockbuster==1.5.26 # via -r requirements/test-common.in cffi==2.0.0 # via # cryptography # pytest-codspeed click==8.3.1 # via wait-for-it coverage==7.13.5 # via # -r requirements/test-common.in # pytest-cov cryptography==46.0.5 # via trustme exceptiongroup==1.3.1 # via pytest execnet==2.1.2 # via pytest-xdist forbiddenfruit==0.1.4 # via blockbuster freezegun==1.5.5 # via -r requirements/test-common.in idna==3.11 # via trustme iniconfig==2.3.0 # via pytest isal==1.8.0 ; python_version < "3.14" # via -r requirements/test-common.in librt==0.8.0 # via mypy markdown-it-py==4.0.0 # via rich mdurl==0.1.2 # via markdown-it-py mypy==1.19.1 ; implementation_name == "cpython" # via -r requirements/test-common.in mypy-extensions==1.1.0 # via mypy packaging==26.0 # via pytest pathspec==1.0.4 # via mypy pkgconfig==1.6.0 # via -r requirements/test-common.in pluggy==1.6.0 # via # pytest # pytest-cov proxy-py==2.4.10 # via -r requirements/test-common.in pycparser==3.0 # via cffi pydantic==2.12.5 # via python-on-whales pydantic-core==2.41.5 # via pydantic pygments==2.19.2 # via # pytest # rich pytest==9.0.2 # via # -r requirements/test-common.in # pytest-codspeed # pytest-cov # pytest-mock # pytest-xdist pytest-codspeed==4.3.0 # via -r requirements/test-common.in pytest-cov==7.0.0 # via -r requirements/test-common.in pytest-mock==3.15.1 # via -r requirements/test-common.in pytest-xdist==3.8.0 # via -r requirements/test-common.in python-dateutil==2.9.0.post0 # via freezegun python-on-whales==0.81.0 # via -r requirements/test-common.in rich==14.3.3 # via pytest-codspeed setuptools-git==1.2 # via -r requirements/test-common.in six==1.17.0 # via python-dateutil tomli==2.4.0 # via # coverage # mypy # pytest trustme==1.2.1 ; platform_machine != "i686" # via -r requirements/test-common.in typing-extensions==4.15.0 # via # cryptography # exceptiongroup # mypy # pydantic # pydantic-core # python-on-whales # typing-inspection typing-inspection==0.4.2 # via pydantic wait-for-it==2.3.0 # via 
-r requirements/test-common.in zlib-ng==1.0.0 # via -r requirements/test-common.in ================================================ FILE: requirements/test-ft.in ================================================ -r base-ft.in -r test-common.in ================================================ FILE: requirements/test-ft.txt ================================================ # # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/test-ft.txt --strip-extras requirements/test-ft.in # aiodns==4.0.0 # via -r requirements/runtime-deps.in aiohappyeyeballs==2.6.1 # via -r requirements/runtime-deps.in aiosignal==1.4.0 # via -r requirements/runtime-deps.in annotated-types==0.7.0 # via pydantic async-timeout==5.0.1 ; python_version < "3.11" # via -r requirements/runtime-deps.in backports-zstd==1.3.0 ; platform_python_implementation == "CPython" and python_version < "3.14" # via -r requirements/runtime-deps.in blockbuster==1.5.26 # via -r requirements/test-common.in brotli==1.2.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in cffi==2.0.0 # via # cryptography # pycares # pytest-codspeed click==8.3.1 # via wait-for-it coverage==7.13.5 # via # -r requirements/test-common.in # pytest-cov cryptography==46.0.5 # via trustme exceptiongroup==1.3.1 # via pytest execnet==2.1.2 # via pytest-xdist forbiddenfruit==0.1.4 # via blockbuster freezegun==1.5.5 # via -r requirements/test-common.in frozenlist==1.8.0 # via # -r requirements/runtime-deps.in # aiosignal gunicorn==25.1.0 # via -r requirements/base-ft.in idna==3.11 # via # trustme # yarl iniconfig==2.3.0 # via pytest isal==1.8.0 ; python_version < "3.14" # via -r requirements/test-common.in librt==0.8.0 # via mypy markdown-it-py==4.0.0 # via rich mdurl==0.1.2 # via markdown-it-py multidict==6.7.1 # via # -r requirements/runtime-deps.in # yarl mypy==1.19.1 ; implementation_name == "cpython" # via -r 
requirements/test-common.in mypy-extensions==1.1.0 # via mypy packaging==26.0 # via # gunicorn # pytest pathspec==1.0.4 # via mypy pkgconfig==1.6.0 # via -r requirements/test-common.in pluggy==1.6.0 # via # pytest # pytest-cov propcache==0.4.1 # via # -r requirements/runtime-deps.in # yarl proxy-py==2.4.10 # via -r requirements/test-common.in pycares==5.0.1 # via aiodns pycparser==3.0 # via cffi pydantic==2.12.5 # via python-on-whales pydantic-core==2.41.5 # via pydantic pygments==2.19.2 # via # pytest # rich pytest==9.0.2 # via # -r requirements/test-common.in # pytest-codspeed # pytest-cov # pytest-mock # pytest-xdist pytest-codspeed==4.3.0 # via -r requirements/test-common.in pytest-cov==7.0.0 # via -r requirements/test-common.in pytest-mock==3.15.1 # via -r requirements/test-common.in pytest-xdist==3.8.0 # via -r requirements/test-common.in python-dateutil==2.9.0.post0 # via freezegun python-on-whales==0.81.0 # via -r requirements/test-common.in rich==14.3.3 # via pytest-codspeed setuptools-git==1.2 # via -r requirements/test-common.in six==1.17.0 # via python-dateutil tomli==2.4.0 # via # coverage # mypy # pytest trustme==1.2.1 ; platform_machine != "i686" # via -r requirements/test-common.in typing-extensions==4.15.0 ; python_version < "3.13" # via # -r requirements/runtime-deps.in # aiosignal # cryptography # exceptiongroup # multidict # mypy # pydantic # pydantic-core # python-on-whales # typing-inspection typing-inspection==0.4.2 # via pydantic wait-for-it==2.3.0 # via -r requirements/test-common.in yarl==1.22.0 # via -r requirements/runtime-deps.in zlib-ng==1.0.0 # via -r requirements/test-common.in ================================================ FILE: requirements/test.in ================================================ -r base.in -r test-common.in ================================================ FILE: requirements/test.txt ================================================ # # This file is autogenerated by pip-compile with Python 3.10 # by the following 
command: # # pip-compile --allow-unsafe --output-file=requirements/test.txt --strip-extras requirements/test.in # aiodns==4.0.0 # via -r requirements/runtime-deps.in aiohappyeyeballs==2.6.1 # via -r requirements/runtime-deps.in aiosignal==1.4.0 # via -r requirements/runtime-deps.in annotated-types==0.7.0 # via pydantic async-timeout==5.0.1 ; python_version < "3.11" # via -r requirements/runtime-deps.in backports-zstd==1.3.0 ; platform_python_implementation == "CPython" and python_version < "3.14" # via -r requirements/runtime-deps.in blockbuster==1.5.26 # via -r requirements/test-common.in brotli==1.2.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in cffi==2.0.0 # via # cryptography # pycares # pytest-codspeed click==8.3.1 # via wait-for-it coverage==7.13.5 # via # -r requirements/test-common.in # pytest-cov cryptography==46.0.5 # via trustme exceptiongroup==1.3.1 # via pytest execnet==2.1.2 # via pytest-xdist forbiddenfruit==0.1.4 # via blockbuster freezegun==1.5.5 # via -r requirements/test-common.in frozenlist==1.8.0 # via # -r requirements/runtime-deps.in # aiosignal gunicorn==25.1.0 # via -r requirements/base.in idna==3.11 # via # trustme # yarl iniconfig==2.3.0 # via pytest isal==1.7.2 ; python_version < "3.14" # via -r requirements/test-common.in librt==0.8.0 # via mypy markdown-it-py==4.0.0 # via rich mdurl==0.1.2 # via markdown-it-py multidict==6.7.1 # via # -r requirements/runtime-deps.in # yarl mypy==1.19.1 ; implementation_name == "cpython" # via -r requirements/test-common.in mypy-extensions==1.1.0 # via mypy packaging==26.0 # via # gunicorn # pytest pathspec==1.0.4 # via mypy pkgconfig==1.6.0 # via -r requirements/test-common.in pluggy==1.6.0 # via # pytest # pytest-cov propcache==0.4.1 # via # -r requirements/runtime-deps.in # yarl proxy-py==2.4.10 # via -r requirements/test-common.in pycares==5.0.1 # via aiodns pycparser==3.0 # via cffi pydantic==2.12.5 # via python-on-whales pydantic-core==2.41.5 # via pydantic 
pygments==2.19.2 # via # pytest # rich pytest==9.0.2 # via # -r requirements/test-common.in # pytest-codspeed # pytest-cov # pytest-mock # pytest-xdist pytest-codspeed==4.3.0 # via -r requirements/test-common.in pytest-cov==7.0.0 # via -r requirements/test-common.in pytest-mock==3.15.1 # via -r requirements/test-common.in pytest-xdist==3.8.0 # via -r requirements/test-common.in python-dateutil==2.9.0.post0 # via freezegun python-on-whales==0.81.0 # via -r requirements/test-common.in rich==14.3.3 # via pytest-codspeed setuptools-git==1.2 # via -r requirements/test-common.in six==1.17.0 # via python-dateutil tomli==2.4.0 # via # coverage # mypy # pytest trustme==1.2.1 ; platform_machine != "i686" # via -r requirements/test-common.in typing-extensions==4.15.0 ; python_version < "3.13" # via # -r requirements/runtime-deps.in # aiosignal # cryptography # exceptiongroup # multidict # mypy # pydantic # pydantic-core # python-on-whales # typing-inspection typing-inspection==0.4.2 # via pydantic uvloop==0.21.0 ; platform_system != "Windows" and implementation_name == "cpython" # via -r requirements/base.in wait-for-it==2.3.0 # via -r requirements/test-common.in yarl==1.22.0 # via -r requirements/runtime-deps.in zlib-ng==1.0.0 # via -r requirements/test-common.in ================================================ FILE: setup.cfg ================================================ [pep8] max-line-length=79 [easy_install] zip_ok = false [flake8] extend-select = B950, # NIC001 -- "Implicitly concatenated str literals on one line" NIC001, # NIC101 -- "Implicitly concatenated bytes literals on one line" NIC101, # TODO: don't disable D*, fix up issues instead ignore = N801,N802,N803,NIC002,NIC102,E203,E226,E305,W504,E252,E301,E302,E501,E704,W503,W504,D1,D4 max-line-length = 88 per-file-ignores = # I900: Shouldn't appear in requirements for examples. 
examples/*:I900 docs/code/*:F841 # flake8-requirements known-modules = proxy.py:[proxy] requirements-file = requirements/test.in requirements-max-depth = 4 [isort] line_length=88 include_trailing_comma=True multi_line_output=3 force_grid_wrap=0 combine_as_imports=True known_third_party=jinja2,pytest,multidict,yarl,gunicorn,freezegun known_first_party=aiohttp,aiohttp_jinja2,aiopg [report] exclude_lines = @abc.abstractmethod @abstractmethod [tool:pytest] addopts = # `pytest-xdist`: --numprocesses=auto # show 10 slowest invocations: --durations=10 # a bit of verbosity doesn't hurt: -v # report all the things == -rxXs: -ra # show values of the local vars in errors: --showlocals # `pytest-cov`: -p pytest_cov --cov=aiohttp --cov=tests/ -m "not dev_mode and not autobahn and not internal" filterwarnings = error ignore:module 'ssl' has no attribute 'OP_NO_COMPRESSION'. The Python interpreter is compiled against OpenSSL < 1.0.0. Ref. https.//docs.python.org/3/library/ssl.html#ssl.OP_NO_COMPRESSION:UserWarning ignore:Unclosed client session 2022.06.15`. ignore:path is deprecated. Use files.. instead. Refer to https.//importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.:DeprecationWarning:certifi.core # Dateutil deprecation warning already fixed upstream. # Can be dropped with the next release, `dateutil > 2.8.2` # https://github.com/dateutil/dateutil/pull/1285 ignore:datetime.*utcfromtimestamp\(\) is deprecated and scheduled for removal:DeprecationWarning:dateutil.tz.tz # Tracked upstream and waiting for PR review # https://github.com/spulec/freezegun/issues/508 # https://github.com/spulec/freezegun/pull/511 ignore:datetime.*utcnow\(\) is deprecated and scheduled for removal:DeprecationWarning:freezegun.api junit_suite_name = aiohttp_test_suite norecursedirs = dist docs build .tox .eggs minversion = 3.8.2 testpaths = tests/ xfail_strict = true markers = autobahn: Autobahn testsuite. Should be run as a separate job. 
# Build script: chooses between a pure-Python build and a build with C
# extensions, then hands the result to setuptools.
import os
import pathlib
import sys

from setuptools import Extension, setup

# Fail early on interpreters older than Python 3.10.
if sys.version_info < (3, 10):
    raise RuntimeError("aiohttp 4.x requires Python 3.10+")


# True when AIOHTTP_USE_SYSTEM_DEPS holds a non-empty value. The unprefixed
# USE_SYSTEM_DEPS is only consulted when the prefixed variable is absent.
USE_SYSTEM_DEPS = bool(
    os.environ.get("AIOHTTP_USE_SYSTEM_DEPS", os.environ.get("USE_SYSTEM_DEPS"))
)
# Any non-empty AIOHTTP_NO_EXTENSIONS value requests a build without extensions.
NO_EXTENSIONS: bool = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))
# Directory containing this file.
HERE = pathlib.Path(__file__).parent
IS_GIT_REPO = (HERE / ".git").exists()


# Extensions are never built on interpreters other than CPython.
if sys.implementation.name != "cpython":
    NO_EXTENSIONS = True


# A git checkout without the vendored llhttp sources cannot be built unless
# system dependencies are used: print a hint and exit with status 2.
# Note that this check does not look at NO_EXTENSIONS.
if (
    not USE_SYSTEM_DEPS
    and IS_GIT_REPO
    and not (HERE / "vendor/llhttp/README.md").exists()
):
    print("Install submodules when building from git clone", file=sys.stderr)
    print("Hint:", file=sys.stderr)
    print(" git submodule update --init", file=sys.stderr)
    sys.exit(2)


# NOTE: makefile cythonizes all Cython modules

if USE_SYSTEM_DEPS:
    # Imported lazily so pkgconfig is only required for system-deps builds.
    import shlex

    import pkgconfig

    # No vendored sources are compiled; the compile and link flags for
    # "libllhttp" come from pkg-config instead.
    llhttp_sources = []
    llhttp_kwargs = {
        "extra_compile_args": shlex.split(pkgconfig.cflags("libllhttp")),
        "extra_link_args": shlex.split(pkgconfig.libs("libllhttp")),
    }
else:
    # Compile the vendored llhttp sources straight into the parser extension.
    llhttp_sources = [
        "vendor/llhttp/build/c/llhttp.c",
        "vendor/llhttp/src/native/api.c",
        "vendor/llhttp/src/native/http.c",
    ]
    llhttp_kwargs = {
        "define_macros": [("LLHTTP_STRICT_MODE", 0)],
        "include_dirs": ["vendor/llhttp/build"],
    }

# Extension modules, each built from already generated .c files (see the
# NOTE above); only the HTTP parser needs the llhttp sources/flags.
extensions = [
    Extension("aiohttp._websocket.mask", ["aiohttp/_websocket/mask.c"]),
    Extension(
        "aiohttp._http_parser",
        [
            "aiohttp/_http_parser.c",
            "aiohttp/_find_header.c",
            *llhttp_sources,
        ],
        **llhttp_kwargs,
    ),
    Extension("aiohttp._http_writer", ["aiohttp/_http_writer.c"]),
    Extension("aiohttp._websocket.reader_c", ["aiohttp/_websocket/reader_c.c"]),
]


# Label used only in the banner printed below.
build_type = "Pure" if NO_EXTENSIONS else "Accelerated"
# With extensions disabled, setup() is called without ext_modules.
setup_kwargs = {} if NO_EXTENSIONS else {"ext_modules": extensions}

# Announce the build flavour on stderr; "{build_type}" is filled in from the
# module-level variable of that name via locals().
print("*********************", file=sys.stderr)
print("* {build_type} build *".format_map(locals()), file=sys.stderr)
print("*********************", file=sys.stderr)

setup(**setup_kwargs)
async def wshandler(request: web.Request) -> web.WebSocketResponse:
    """Echo WebSocket handler used as the server under test.

    The response is created with ``autoclose=False``, prepared for *request*
    and appended to ``request.app[websockets]`` (the list ``on_shutdown``
    iterates over).  Every TEXT message is sent back with ``send_str`` and
    every BINARY message with ``send_bytes``; the first message of any other
    type ends the loop.  The response object is returned in all cases.
    """
    socket = web.WebSocketResponse(autoclose=False)
    await socket.prepare(request)
    request.app[websockets].append(socket)

    async for message in socket:
        kind = message.type
        if kind is web.WSMsgType.TEXT:
            await socket.send_str(message.data)
            continue
        if kind is web.WSMsgType.BINARY:
            await socket.send_bytes(message.data)
            continue
        # Anything that is neither text nor binary stops the echo loop.
        break

    return socket
@pytest.mark.autobahn
def test_client(report_dir: Path, request: pytest.FixtureRequest) -> None:
    """Run the aiohttp WebSocket client against the Autobahn fuzzing server.

    The client script is launched through ``wait-for-it`` (targeting
    ``localhost:9001``) under ``coverage run -a``, while the
    ``autobahn-testsuite`` image is started detached with port 9001 published
    and the ``client`` config directory and *report_dir* mounted.  After the
    client process ends, the ``aiohttp`` results under ``clients`` in
    *report_dir* are loaded and checked with ``process_xfail``; any report it
    returns fails the test.
    """
    client = subprocess.Popen(
        (
            "wait-for-it",
            "-s",
            "localhost:9001",
            "--",
            "coverage",
            "run",
            "-a",
            "tests/autobahn/client/client.py",
        )
    )
    # Bound before the ``try`` so that cleanup stays valid when ``docker.run``
    # itself raises: otherwise the name would be unbound in ``finally`` and
    # the resulting error would hide the real failure.
    autobahn_container = None
    try:
        autobahn_container = docker.run(
            detach=True,
            image="autobahn-testsuite",
            name="autobahn",
            publish=[(9001, 9001)],
            remove=True,
            volumes=[
                (request.path.parent / "client", "/config"),
                (report_dir, "/reports"),
            ],
        )
        client.wait()
    finally:
        # Same order as before: make sure the client process is gone first,
        # then stop the container, but only if one was actually started.
        client.terminate()
        client.wait()
        if autobahn_container is not None:
            autobahn_container.stop()

    results = get_test_results(report_dir / "clients", "aiohttp")
    # Known failures: case number -> the ``result`` text expected in its report.
    xfail = {
        "3.4": "Actual events match at least one expected.",
        "7.9.5": "The close code should have been 1002 or empty",
        "9.1.4": "Did not receive message within 100 seconds.",
        "9.1.5": "Did not receive message within 100 seconds.",
        "9.1.6": "Did not receive message within 100 seconds.",
        "9.2.4": "Did not receive message within 10 seconds.",
        "9.2.5": "Did not receive message within 100 seconds.",
        "9.2.6": "Did not receive message within 100 seconds.",
        "9.3.1": "Did not receive message within 100 seconds.",
        "9.3.2": "Did not receive message within 100 seconds.",
        "9.3.3": "Did not receive message within 100 seconds.",
        "9.3.4": "Did not receive message within 100 seconds.",
        "9.3.5": "Did not receive message within 100 seconds.",
        "9.3.6": "Did not receive message within 100 seconds.",
        "9.3.7": "Did not receive message within 100 seconds.",
        "9.3.8": "Did not receive message within 100 seconds.",
        "9.3.9": "Did not receive message within 100 seconds.",
        "9.4.1": "Did not receive message within 100 seconds.",
        "9.4.2": "Did not receive message within 100 seconds.",
        "9.4.3": "Did not receive message within 100 seconds.",
        "9.4.4": "Did not receive message within 100 seconds.",
        "9.4.5": "Did not receive message within 100 seconds.",
        "9.4.6": "Did not receive message within 100 seconds.",
        "9.4.7": "Did not receive message within 100 seconds.",
        "9.4.8": "Did not receive message within 100 seconds.",
        "9.4.9": "Did not receive message within 100 seconds.",
    }

    assert not process_xfail(results, xfail)
finally: server.terminate() server.wait() results = get_test_results(report_dir / "servers", "AutobahnServer") xfail = { "7.9.5": "The close code should have been 1002 or empty", "9.1.4": "Did not receive message within 100 seconds.", "9.1.5": "Did not receive message within 100 seconds.", "9.1.6": "Did not receive message within 100 seconds.", "9.2.4": "Did not receive message within 10 seconds.", "9.2.5": "Did not receive message within 100 seconds.", "9.2.6": "Did not receive message within 100 seconds.", "9.3.1": "Did not receive message within 100 seconds.", "9.3.2": "Did not receive message within 100 seconds.", "9.3.3": "Did not receive message within 100 seconds.", "9.3.4": "Did not receive message within 100 seconds.", "9.3.5": "Did not receive message within 100 seconds.", "9.3.6": "Did not receive message within 100 seconds.", "9.3.7": "Did not receive message within 100 seconds.", "9.3.8": "Did not receive message within 100 seconds.", "9.3.9": "Did not receive message within 100 seconds.", "9.4.1": "Did not receive message within 100 seconds.", "9.4.2": "Did not receive message within 100 seconds.", "9.4.3": "Did not receive message within 100 seconds.", "9.4.4": "Did not receive message within 100 seconds.", "9.4.5": "Did not receive message within 100 seconds.", "9.4.6": "Did not receive message within 100 seconds.", "9.4.7": "Did not receive message within 100 seconds.", "9.4.8": "Did not receive message within 100 seconds.", "9.4.9": "Did not receive message within 100 seconds.", } assert not process_xfail(results, xfail) ================================================ FILE: tests/conftest.py ================================================ from __future__ import annotations # TODO(PY311): Remove import asyncio import base64 import os import platform import socket import ssl import sys import time from collections.abc import AsyncIterator, Callable, Iterator from concurrent.futures import Future, ThreadPoolExecutor from hashlib import md5, sha1, 
def pytest_configure(config: pytest.Config) -> None:
    """Register a platform-specific warning filter at configuration time.

    On Windows with Python 3.10/3.11, proxy.py's threaded mode can leave
    sockets not fully released by the time pytest's unraisableexception
    plugin collects warnings during teardown.  Those warnings are not
    actionable and only affect older Python versions, so they are ignored
    there; on every other platform/version nothing is registered.
    """
    needs_filter = os.name == "nt" and sys.version_info < (3, 12)
    if not needs_filter:
        return

    socket_warning_filter = (
        "ignore:Exception ignored in.*socket.*:pytest.PytestUnraisableExceptionWarning"
    )
    config.addinivalue_line("filterwarnings", socket_warning_filter)
if "skip_blockbuster" in request.node.keywords: yield return # No blockbuster for benchmark tests. node = request.node.parent while node: if node.name.startswith("test_benchmarks"): yield return node = node.parent with blockbuster_ctx( "aiohttp", excluded_modules=["aiohttp.pytest_plugin", "aiohttp.test_utils"] ) as bb: for func in [ "os.getcwd", "os.readlink", "os.stat", "os.path.abspath", "os.path.samestat", ]: bb.functions[func].can_block_in( "aiohttp/web_urldispatcher.py", "add_static" ) # Note: coverage.py uses locking internally which can cause false positives # in blockbuster when it instruments code. This is particularly problematic # on Windows where it can lead to flaky test failures. # Additionally, we're not particularly worried about threading.Lock.acquire happening # by accident in this codebase as we primarily use asyncio.Lock for # synchronization in async code. # Allow lock.acquire calls to prevent these false positives bb.functions["threading.Lock.acquire"].deactivate() yield @pytest.fixture def tls_certificate_authority() -> trustme.CA: return trustme.CA() @pytest.fixture def tls_certificate(tls_certificate_authority: trustme.CA) -> trustme.LeafCert: return tls_certificate_authority.issue_cert( "localhost", "xn--prklad-4va.localhost", "127.0.0.1", "::1", ) @pytest.fixture def ssl_ctx(tls_certificate: trustme.LeafCert) -> ssl.SSLContext: ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) tls_certificate.configure_cert(ssl_ctx) return ssl_ctx @pytest.fixture def client_ssl_ctx(tls_certificate_authority: trustme.CA) -> ssl.SSLContext: ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) tls_certificate_authority.configure_trust(ssl_ctx) return ssl_ctx @pytest.fixture def tls_ca_certificate_pem_path(tls_certificate_authority: trustme.CA) -> Iterator[str]: with tls_certificate_authority.cert_pem.tempfile() as ca_cert_pem: yield ca_cert_pem @pytest.fixture def tls_certificate_pem_path(tls_certificate: trustme.LeafCert) -> Iterator[str]: 
with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem: yield cert_pem @pytest.fixture def tls_certificate_pem_bytes(tls_certificate: trustme.LeafCert) -> bytes: return tls_certificate.cert_chain_pems[0].bytes() @pytest.fixture def tls_certificate_fingerprint_sha256(tls_certificate_pem_bytes: bytes) -> bytes: tls_cert_der = ssl.PEM_cert_to_DER_cert(tls_certificate_pem_bytes.decode()) return sha256(tls_cert_der).digest() @pytest.fixture def pipe_name() -> str: name = rf"\\.\pipe\{uuid4().hex}" return name @pytest.fixture def create_mocked_conn( loop: asyncio.AbstractEventLoop, ) -> Iterator[Callable[[], ResponseHandler]]: def _proto_factory() -> Any: proto = mock.create_autospec(ResponseHandler, instance=True) proto.closed = loop.create_future() proto.closed.set_result(None) return proto yield _proto_factory @pytest.fixture def unix_sockname( tmp_path: Path, tmp_path_factory: pytest.TempPathFactory ) -> Iterator[str]: # Generate an fs path to the UNIX domain socket for testing. # N.B. Different OS kernels have different fs path length limitations # for it. For Linux, it's 108, for HP-UX it's 92 (or higher) depending # on its version. For most of the BSDs (Open, Free, macOS) it's # mostly 104 but sometimes it can be down to 100. # Ref: https://github.com/aio-libs/aiohttp/issues/3572 if not hasattr(socket, "AF_UNIX"): pytest.skip("requires UNIX sockets") max_sock_len = 92 if IS_HPUX else 108 if IS_LINUX else 100 # Amount of bytes allocated for the UNIX socket path by OS kernel. 
# Ref: https://unix.stackexchange.com/a/367012/27133 sock_file_name = "unix.sock" root_tmp_dir = Path("/tmp").resolve() os_tmp_dir = Path(os.getenv("TMPDIR", "/tmp")).resolve() original_base_tmp_path = Path( str(tmp_path_factory.getbasetemp()), ).resolve() original_base_tmp_path_hash = md5( str(original_base_tmp_path).encode(), ).hexdigest() def make_tmp_dir(base_tmp_dir: Path) -> TemporaryDirectory[str]: return TemporaryDirectory( dir=str(base_tmp_dir), prefix="pt-", suffix=f"-{original_base_tmp_path_hash!s}", ) def assert_sock_fits(sock_path: str) -> None: sock_path_len = len(sock_path.encode()) # exit-check to verify that it's correct and simplify debugging # in the future assert sock_path_len <= max_sock_len, ( "Suggested UNIX socket ({sock_path}) is {sock_path_len} bytes " "long but the current kernel only has {max_sock_len} bytes " "allocated to hold it so it must be shorter. " "See https://github.com/aio-libs/aiohttp/issues/3572 " "for more info." ).format_map(locals()) paths = original_base_tmp_path, os_tmp_dir, root_tmp_dir unique_paths = [p for n, p in enumerate(paths) if p not in paths[:n]] paths_num = len(unique_paths) for num, tmp_dir_path in enumerate(paths, 1): # pragma: no branch with make_tmp_dir(tmp_dir_path) as tmps: tmpd = Path(tmps).resolve() sock_path = str(tmpd / sock_file_name) sock_path_len = len(sock_path.encode()) if num >= paths_num: # exit-check to verify that it's correct and simplify # debugging in the future assert_sock_fits(sock_path) if sock_path_len <= max_sock_len: yield sock_path return @pytest.fixture async def event_loop(loop: asyncio.AbstractEventLoop) -> asyncio.AbstractEventLoop: return asyncio.get_running_loop() @pytest.fixture def selector_loop() -> Iterator[asyncio.AbstractEventLoop]: factory = asyncio.SelectorEventLoop with loop_context(factory) as _loop: asyncio.set_event_loop(_loop) yield _loop @pytest.fixture def uvloop_loop() -> Iterator[asyncio.AbstractEventLoop]: if uvloop is None: pytest.skip("uvloop is not 
installed") factory = uvloop.new_event_loop with loop_context(factory) as _loop: asyncio.set_event_loop(_loop) yield _loop @pytest.fixture def netrc_contents( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, request: pytest.FixtureRequest, ) -> Path: """ Prepare :file:`.netrc` with given contents. Monkey-patches :envvar:`NETRC` to point to created file. """ netrc_contents = getattr(request, "param", None) netrc_file_path = tmp_path / ".netrc" if netrc_contents is not None: netrc_file_path.write_text(netrc_contents) monkeypatch.setenv("NETRC", str(netrc_file_path)) return netrc_file_path @pytest.fixture def netrc_default_contents(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path: """Create a temporary netrc file with default test credentials and set NETRC env var.""" netrc_file = tmp_path / ".netrc" netrc_file.write_text("default login netrc_user password netrc_pass\n") monkeypatch.setenv("NETRC", str(netrc_file)) return netrc_file @pytest.fixture def no_netrc(monkeypatch: pytest.MonkeyPatch) -> None: """Ensure NETRC environment variable is not set.""" monkeypatch.delenv("NETRC", raising=False) @pytest.fixture def netrc_other_host(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path: """Create a temporary netrc file with credentials for a different host and set NETRC env var.""" netrc_file = tmp_path / ".netrc" netrc_file.write_text("machine other.example.com login user password pass\n") monkeypatch.setenv("NETRC", str(netrc_file)) return netrc_file @pytest.fixture def netrc_home_directory(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path: """Create a netrc file in a mocked home directory without setting NETRC env var.""" home_dir = tmp_path / "home" home_dir.mkdir() netrc_filename = "_netrc" if platform.system() == "Windows" else ".netrc" netrc_file = home_dir / netrc_filename netrc_file.write_text("default login netrc_user password netrc_pass\n") home_env_var = "USERPROFILE" if platform.system() == "Windows" else "HOME" 
monkeypatch.setenv(home_env_var, str(home_dir)) # Ensure NETRC env var is not set monkeypatch.delenv("NETRC", raising=False) return netrc_file @pytest.fixture def start_connection() -> Iterator[mock.Mock]: with mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, return_value=mock.create_autospec(socket.socket, spec_set=True, instance=True), ) as start_connection_mock: yield start_connection_mock @pytest.fixture def key_data() -> bytes: return os.urandom(16) @pytest.fixture def key(key_data: bytes) -> bytes: return base64.b64encode(key_data) @pytest.fixture def ws_key(key: bytes) -> str: return base64.b64encode(sha1(key + WS_KEY).digest()).decode() @pytest.fixture def enable_cleanup_closed() -> Iterator[None]: """Fixture to override the NEEDS_CLEANUP_CLOSED flag. On Python 3.12.7+ and 3.13.1+ enable_cleanup_closed is not needed, however we still want to test that it works. """ with mock.patch("aiohttp.connector.NEEDS_CLEANUP_CLOSED", True): yield @pytest.fixture def unused_port_socket() -> Iterator[socket.socket]: """Return a socket that is unused on the current host. Unlike aiohttp_used_port, the socket is yielded so there is no race condition between checking if the port is in use and binding to it later in the test. 
""" s = get_unused_port_socket("127.0.0.1") try: yield s finally: s.close() @pytest.fixture(params=["zlib", "zlib_ng.zlib_ng", "isal.isal_zlib"]) def parametrize_zlib_backend( request: pytest.FixtureRequest, ) -> Iterator[None]: original_backend: ZLibBackendProtocol = ZLibBackend._zlib_backend backend = pytest.importorskip(request.param) set_zlib_backend(backend) yield set_zlib_backend(original_backend) @pytest.fixture() async def cleanup_payload_pending_file_closes( loop: asyncio.AbstractEventLoop, ) -> AsyncIterator[None]: """Ensure all pending file close operations complete during test teardown.""" yield if payload._CLOSE_FUTURES: # Only wait for futures from the current loop loop_futures = [f for f in payload._CLOSE_FUTURES if f.get_loop() is loop] if loop_futures: await asyncio.gather(*loop_futures, return_exceptions=True) @pytest.fixture async def make_client_request( loop: asyncio.AbstractEventLoop, ) -> AsyncIterator[Callable[[str, URL, Unpack[ClientRequestArgs]], ClientRequest]]: """Fixture to help creating test ClientRequest objects with defaults.""" requests: list[ClientRequest] = [] sessions: list[ClientSession] = [] def maker( method: str, url: URL, **kwargs: Unpack[ClientRequestArgs] ) -> ClientRequest: session = ClientSession() sessions.append(session) default_args: ClientRequestArgs = { "loop": loop, "params": {}, "headers": CIMultiDict[str](), "skip_auto_headers": None, "data": None, "cookies": BaseCookie[str](), "auth": None, "version": HttpVersion11, "compress": False, "chunked": None, "expect100": False, "response_class": ClientResponse, "proxy": None, "proxy_auth": None, "timer": TimerNoop(), "session": session, "ssl": True, "proxy_headers": None, "traces": [], "trust_env": False, "server_hostname": None, } request = ClientRequest(method, url, **(default_args | kwargs)) requests.append(request) return request yield maker await asyncio.gather( *(request._close() for request in requests), *(session.close() for session in sessions), ) 
@pytest.fixture def slow_executor() -> Iterator[ThreadPoolExecutor]: """Executor that adds delay to simulate slow operations. Useful for testing cancellation and race conditions in compression tests. """ class SlowExecutor(ThreadPoolExecutor): """Executor that adds delay to operations.""" def submit( self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any ) -> Future[Any]: def slow_fn(*args: Any, **kwargs: Any) -> Any: time.sleep(0.05) # Add delay to simulate slow operation return fn(*args, **kwargs) return super().submit(slow_fn, *args, **kwargs) executor = SlowExecutor(max_workers=10) yield executor executor.shutdown(wait=True) ================================================ FILE: tests/data.unknown_mime_type ================================================ file content ================================================ FILE: tests/data.zero_bytes ================================================ ================================================ FILE: tests/github-urls.json ================================================ [ "/", "/advisories", "/advisories/{ghsa_id}", "/app", "/app-manifests/{code}/conversions", "/app/hook/config", "/app/hook/deliveries", "/app/hook/deliveries/{delivery_id}", "/app/hook/deliveries/{delivery_id}/attempts", "/app/installation-requests", "/app/installations", "/app/installations/{installation_id}", "/app/installations/{installation_id}/access_tokens", "/app/installations/{installation_id}/suspended", "/applications/{client_id}/grant", "/applications/{client_id}/token", "/applications/{client_id}/token/scoped", "/apps/{app_slug}", "/assignments/{assignment_id}", "/assignments/{assignment_id}/accepted_assignments", "/assignments/{assignment_id}/grades", "/classrooms", "/classrooms/{classroom_id}", "/classrooms/{classroom_id}/assignments", "/codes_of_conduct", "/codes_of_conduct/{key}", "/emojis", "/enterprises/{enterprise}/copilot/billing/seats", "/enterprises/{enterprise}/copilot/metrics", "/enterprises/{enterprise}/copilot/usage", 
"/enterprises/{enterprise}/dependabot/alerts", "/enterprises/{enterprise}/secret-scanning/alerts", "/enterprises/{enterprise}/team/{team_slug}/copilot/metrics", "/enterprises/{enterprise}/team/{team_slug}/copilot/usage", "/events", "/feeds", "/gists", "/gists/public", "/gists/starred", "/gists/{gist_id}", "/gists/{gist_id}/comments", "/gists/{gist_id}/comments/{comment_id}", "/gists/{gist_id}/commits", "/gists/{gist_id}/forks", "/gists/{gist_id}/star", "/gists/{gist_id}/{sha}", "/gitignore/templates", "/gitignore/templates/{name}", "/installation/repositories", "/installation/token", "/issues", "/licenses", "/licenses/{license}", "/markdown", "/markdown/raw", "/marketplace_listing/accounts/{account_id}", "/marketplace_listing/plans", "/marketplace_listing/plans/{plan_id}/accounts", "/marketplace_listing/stubbed/accounts/{account_id}", "/marketplace_listing/stubbed/plans", "/marketplace_listing/stubbed/plans/{plan_id}/accounts", "/meta", "/networks/{owner}/{repo}/events", "/notifications", "/notifications/threads/{thread_id}", "/notifications/threads/{thread_id}/subscription", "/octocat", "/organizations", "/orgs/{org}", "/orgs/{org}/actions/cache/usage", "/orgs/{org}/actions/cache/usage-by-repository", "/orgs/{org}/actions/oidc/customization/sub", "/orgs/{org}/actions/permissions", "/orgs/{org}/actions/permissions/repositories", "/orgs/{org}/actions/permissions/repositories/{repository_id}", "/orgs/{org}/actions/permissions/selected-actions", "/orgs/{org}/actions/permissions/workflow", "/orgs/{org}/actions/runner-groups", "/orgs/{org}/actions/runner-groups/{runner_group_id}", "/orgs/{org}/actions/runner-groups/{runner_group_id}/repositories", "/orgs/{org}/actions/runner-groups/{runner_group_id}/repositories/{repository_id}", "/orgs/{org}/actions/runner-groups/{runner_group_id}/runners", "/orgs/{org}/actions/runner-groups/{runner_group_id}/runners/{runner_id}", "/orgs/{org}/actions/runners", "/orgs/{org}/actions/runners/downloads", 
"/orgs/{org}/actions/runners/generate-jitconfig", "/orgs/{org}/actions/runners/registration-token", "/orgs/{org}/actions/runners/remove-token", "/orgs/{org}/actions/runners/{runner_id}", "/orgs/{org}/actions/runners/{runner_id}/labels", "/orgs/{org}/actions/runners/{runner_id}/labels/{name}", "/orgs/{org}/actions/secrets", "/orgs/{org}/actions/secrets/public-key", "/orgs/{org}/actions/secrets/{secret_name}", "/orgs/{org}/actions/secrets/{secret_name}/repositories", "/orgs/{org}/actions/secrets/{secret_name}/repositories/{repository_id}", "/orgs/{org}/actions/variables", "/orgs/{org}/actions/variables/{name}", "/orgs/{org}/actions/variables/{name}/repositories", "/orgs/{org}/actions/variables/{name}/repositories/{repository_id}", "/orgs/{org}/attestations/{subject_digest}", "/orgs/{org}/blocks", "/orgs/{org}/blocks/{username}", "/orgs/{org}/code-scanning/alerts", "/orgs/{org}/code-security/configurations", "/orgs/{org}/code-security/configurations/defaults", "/orgs/{org}/code-security/configurations/detach", "/orgs/{org}/code-security/configurations/{configuration_id}", "/orgs/{org}/code-security/configurations/{configuration_id}/attach", "/orgs/{org}/code-security/configurations/{configuration_id}/defaults", "/orgs/{org}/code-security/configurations/{configuration_id}/repositories", "/orgs/{org}/codespaces", "/orgs/{org}/codespaces/access", "/orgs/{org}/codespaces/access/selected_users", "/orgs/{org}/codespaces/secrets", "/orgs/{org}/codespaces/secrets/public-key", "/orgs/{org}/codespaces/secrets/{secret_name}", "/orgs/{org}/codespaces/secrets/{secret_name}/repositories", "/orgs/{org}/codespaces/secrets/{secret_name}/repositories/{repository_id}", "/orgs/{org}/copilot/billing", "/orgs/{org}/copilot/billing/seats", "/orgs/{org}/copilot/billing/selected_teams", "/orgs/{org}/copilot/billing/selected_users", "/orgs/{org}/copilot/metrics", "/orgs/{org}/copilot/usage", "/orgs/{org}/dependabot/alerts", "/orgs/{org}/dependabot/secrets", 
"/orgs/{org}/dependabot/secrets/public-key", "/orgs/{org}/dependabot/secrets/{secret_name}", "/orgs/{org}/dependabot/secrets/{secret_name}/repositories", "/orgs/{org}/dependabot/secrets/{secret_name}/repositories/{repository_id}", "/orgs/{org}/docker/conflicts", "/orgs/{org}/events", "/orgs/{org}/failed_invitations", "/orgs/{org}/hooks", "/orgs/{org}/hooks/{hook_id}", "/orgs/{org}/hooks/{hook_id}/config", "/orgs/{org}/hooks/{hook_id}/deliveries", "/orgs/{org}/hooks/{hook_id}/deliveries/{delivery_id}", "/orgs/{org}/hooks/{hook_id}/deliveries/{delivery_id}/attempts", "/orgs/{org}/hooks/{hook_id}/pings", "/orgs/{org}/insights/api/route-stats/{actor_type}/{actor_id}", "/orgs/{org}/insights/api/subject-stats", "/orgs/{org}/insights/api/summary-stats", "/orgs/{org}/insights/api/summary-stats/users/{user_id}", "/orgs/{org}/insights/api/summary-stats/{actor_type}/{actor_id}", "/orgs/{org}/insights/api/time-stats", "/orgs/{org}/insights/api/time-stats/users/{user_id}", "/orgs/{org}/insights/api/time-stats/{actor_type}/{actor_id}", "/orgs/{org}/insights/api/user-stats/{user_id}", "/orgs/{org}/installation", "/orgs/{org}/installations", "/orgs/{org}/interaction-limits", "/orgs/{org}/invitations", "/orgs/{org}/invitations/{invitation_id}", "/orgs/{org}/invitations/{invitation_id}/teams", "/orgs/{org}/issues", "/orgs/{org}/members", "/orgs/{org}/members/{username}", "/orgs/{org}/members/{username}/codespaces", "/orgs/{org}/members/{username}/codespaces/{codespace_name}", "/orgs/{org}/members/{username}/codespaces/{codespace_name}/stop", "/orgs/{org}/members/{username}/copilot", "/orgs/{org}/memberships/{username}", "/orgs/{org}/migrations", "/orgs/{org}/migrations/{migration_id}", "/orgs/{org}/migrations/{migration_id}/archive", "/orgs/{org}/migrations/{migration_id}/repos/{repo_name}/lock", "/orgs/{org}/migrations/{migration_id}/repositories", "/orgs/{org}/organization-roles", "/orgs/{org}/organization-roles/teams/{team_slug}", 
"/orgs/{org}/organization-roles/teams/{team_slug}/{role_id}", "/orgs/{org}/organization-roles/users/{username}", "/orgs/{org}/organization-roles/users/{username}/{role_id}", "/orgs/{org}/organization-roles/{role_id}", "/orgs/{org}/organization-roles/{role_id}/teams", "/orgs/{org}/organization-roles/{role_id}/users", "/orgs/{org}/outside_collaborators", "/orgs/{org}/outside_collaborators/{username}", "/orgs/{org}/packages", "/orgs/{org}/packages/{package_type}/{package_name}", "/orgs/{org}/packages/{package_type}/{package_name}/restore", "/orgs/{org}/packages/{package_type}/{package_name}/versions", "/orgs/{org}/packages/{package_type}/{package_name}/versions/{package_version_id}", "/orgs/{org}/packages/{package_type}/{package_name}/versions/{package_version_id}/restore", "/orgs/{org}/personal-access-token-requests", "/orgs/{org}/personal-access-token-requests/{pat_request_id}", "/orgs/{org}/personal-access-token-requests/{pat_request_id}/repositories", "/orgs/{org}/personal-access-tokens", "/orgs/{org}/personal-access-tokens/{pat_id}", "/orgs/{org}/personal-access-tokens/{pat_id}/repositories", "/orgs/{org}/projects", "/orgs/{org}/properties/schema", "/orgs/{org}/properties/schema/{custom_property_name}", "/orgs/{org}/properties/values", "/orgs/{org}/public_members", "/orgs/{org}/public_members/{username}", "/orgs/{org}/repos", "/orgs/{org}/rulesets", "/orgs/{org}/rulesets/rule-suites", "/orgs/{org}/rulesets/rule-suites/{rule_suite_id}", "/orgs/{org}/rulesets/{ruleset_id}", "/orgs/{org}/secret-scanning/alerts", "/orgs/{org}/security-advisories", "/orgs/{org}/security-managers", "/orgs/{org}/security-managers/teams/{team_slug}", "/orgs/{org}/settings/billing/actions", "/orgs/{org}/settings/billing/packages", "/orgs/{org}/settings/billing/shared-storage", "/orgs/{org}/team/{team_slug}/copilot/metrics", "/orgs/{org}/team/{team_slug}/copilot/usage", "/orgs/{org}/teams", "/orgs/{org}/teams/{team_slug}", "/orgs/{org}/teams/{team_slug}/discussions", 
"/orgs/{org}/teams/{team_slug}/discussions/{discussion_number}", "/orgs/{org}/teams/{team_slug}/discussions/{discussion_number}/comments", "/orgs/{org}/teams/{team_slug}/discussions/{discussion_number}/comments/{comment_number}", "/orgs/{org}/teams/{team_slug}/discussions/{discussion_number}/comments/{comment_number}/reactions", "/orgs/{org}/teams/{team_slug}/discussions/{discussion_number}/comments/{comment_number}/reactions/{reaction_id}", "/orgs/{org}/teams/{team_slug}/discussions/{discussion_number}/reactions", "/orgs/{org}/teams/{team_slug}/discussions/{discussion_number}/reactions/{reaction_id}", "/orgs/{org}/teams/{team_slug}/invitations", "/orgs/{org}/teams/{team_slug}/members", "/orgs/{org}/teams/{team_slug}/memberships/{username}", "/orgs/{org}/teams/{team_slug}/projects", "/orgs/{org}/teams/{team_slug}/projects/{project_id}", "/orgs/{org}/teams/{team_slug}/repos", "/orgs/{org}/teams/{team_slug}/repos/{owner}/{repo}", "/orgs/{org}/teams/{team_slug}/teams", "/orgs/{org}/{security_product}/{enablement}", "/projects/columns/cards/{card_id}", "/projects/columns/cards/{card_id}/moves", "/projects/columns/{column_id}", "/projects/columns/{column_id}/cards", "/projects/columns/{column_id}/moves", "/projects/{project_id}", "/projects/{project_id}/collaborators", "/projects/{project_id}/collaborators/{username}", "/projects/{project_id}/collaborators/{username}/permission", "/projects/{project_id}/columns", "/rate_limit", "/repos/{owner}/{repo}", "/repos/{owner}/{repo}/actions/artifacts", "/repos/{owner}/{repo}/actions/artifacts/{artifact_id}", "/repos/{owner}/{repo}/actions/artifacts/{artifact_id}/{archive_format}", "/repos/{owner}/{repo}/actions/cache/usage", "/repos/{owner}/{repo}/actions/caches", "/repos/{owner}/{repo}/actions/caches/{cache_id}", "/repos/{owner}/{repo}/actions/jobs/{job_id}", "/repos/{owner}/{repo}/actions/jobs/{job_id}/logs", "/repos/{owner}/{repo}/actions/jobs/{job_id}/rerun", "/repos/{owner}/{repo}/actions/oidc/customization/sub", 
"/repos/{owner}/{repo}/actions/organization-secrets", "/repos/{owner}/{repo}/actions/organization-variables", "/repos/{owner}/{repo}/actions/permissions", "/repos/{owner}/{repo}/actions/permissions/access", "/repos/{owner}/{repo}/actions/permissions/selected-actions", "/repos/{owner}/{repo}/actions/permissions/workflow", "/repos/{owner}/{repo}/actions/runners", "/repos/{owner}/{repo}/actions/runners/downloads", "/repos/{owner}/{repo}/actions/runners/generate-jitconfig", "/repos/{owner}/{repo}/actions/runners/registration-token", "/repos/{owner}/{repo}/actions/runners/remove-token", "/repos/{owner}/{repo}/actions/runners/{runner_id}", "/repos/{owner}/{repo}/actions/runners/{runner_id}/labels", "/repos/{owner}/{repo}/actions/runners/{runner_id}/labels/{name}", "/repos/{owner}/{repo}/actions/runs", "/repos/{owner}/{repo}/actions/runs/{run_id}", "/repos/{owner}/{repo}/actions/runs/{run_id}/approvals", "/repos/{owner}/{repo}/actions/runs/{run_id}/approve", "/repos/{owner}/{repo}/actions/runs/{run_id}/artifacts", "/repos/{owner}/{repo}/actions/runs/{run_id}/attempts/{attempt_number}", "/repos/{owner}/{repo}/actions/runs/{run_id}/attempts/{attempt_number}/jobs", "/repos/{owner}/{repo}/actions/runs/{run_id}/attempts/{attempt_number}/logs", "/repos/{owner}/{repo}/actions/runs/{run_id}/cancel", "/repos/{owner}/{repo}/actions/runs/{run_id}/deployment_protection_rule", "/repos/{owner}/{repo}/actions/runs/{run_id}/force-cancel", "/repos/{owner}/{repo}/actions/runs/{run_id}/jobs", "/repos/{owner}/{repo}/actions/runs/{run_id}/logs", "/repos/{owner}/{repo}/actions/runs/{run_id}/pending_deployments", "/repos/{owner}/{repo}/actions/runs/{run_id}/rerun", "/repos/{owner}/{repo}/actions/runs/{run_id}/rerun-failed-jobs", "/repos/{owner}/{repo}/actions/runs/{run_id}/timing", "/repos/{owner}/{repo}/actions/secrets", "/repos/{owner}/{repo}/actions/secrets/public-key", "/repos/{owner}/{repo}/actions/secrets/{secret_name}", "/repos/{owner}/{repo}/actions/variables", 
"/repos/{owner}/{repo}/actions/variables/{name}", "/repos/{owner}/{repo}/actions/workflows", "/repos/{owner}/{repo}/actions/workflows/{workflow_id}", "/repos/{owner}/{repo}/actions/workflows/{workflow_id}/disable", "/repos/{owner}/{repo}/actions/workflows/{workflow_id}/dispatches", "/repos/{owner}/{repo}/actions/workflows/{workflow_id}/enable", "/repos/{owner}/{repo}/actions/workflows/{workflow_id}/runs", "/repos/{owner}/{repo}/actions/workflows/{workflow_id}/timing", "/repos/{owner}/{repo}/activity", "/repos/{owner}/{repo}/assignees", "/repos/{owner}/{repo}/assignees/{assignee}", "/repos/{owner}/{repo}/attestations", "/repos/{owner}/{repo}/attestations/{subject_digest}", "/repos/{owner}/{repo}/autolinks", "/repos/{owner}/{repo}/autolinks/{autolink_id}", "/repos/{owner}/{repo}/automated-security-fixes", "/repos/{owner}/{repo}/branches", "/repos/{owner}/{repo}/branches/{branch}", "/repos/{owner}/{repo}/branches/{branch}/protection", "/repos/{owner}/{repo}/branches/{branch}/protection/enforce_admins", "/repos/{owner}/{repo}/branches/{branch}/protection/required_pull_request_reviews", "/repos/{owner}/{repo}/branches/{branch}/protection/required_signatures", "/repos/{owner}/{repo}/branches/{branch}/protection/required_status_checks", "/repos/{owner}/{repo}/branches/{branch}/protection/required_status_checks/contexts", "/repos/{owner}/{repo}/branches/{branch}/protection/restrictions", "/repos/{owner}/{repo}/branches/{branch}/protection/restrictions/apps", "/repos/{owner}/{repo}/branches/{branch}/protection/restrictions/teams", "/repos/{owner}/{repo}/branches/{branch}/protection/restrictions/users", "/repos/{owner}/{repo}/branches/{branch}/rename", "/repos/{owner}/{repo}/check-runs", "/repos/{owner}/{repo}/check-runs/{check_run_id}", "/repos/{owner}/{repo}/check-runs/{check_run_id}/annotations", "/repos/{owner}/{repo}/check-runs/{check_run_id}/rerequest", "/repos/{owner}/{repo}/check-suites", "/repos/{owner}/{repo}/check-suites/preferences", 
"/repos/{owner}/{repo}/check-suites/{check_suite_id}", "/repos/{owner}/{repo}/check-suites/{check_suite_id}/check-runs", "/repos/{owner}/{repo}/check-suites/{check_suite_id}/rerequest", "/repos/{owner}/{repo}/code-scanning/alerts", "/repos/{owner}/{repo}/code-scanning/alerts/{alert_number}", "/repos/{owner}/{repo}/code-scanning/alerts/{alert_number}/instances", "/repos/{owner}/{repo}/code-scanning/analyses", "/repos/{owner}/{repo}/code-scanning/analyses/{analysis_id}", "/repos/{owner}/{repo}/code-scanning/codeql/databases", "/repos/{owner}/{repo}/code-scanning/codeql/databases/{language}", "/repos/{owner}/{repo}/code-scanning/codeql/variant-analyses", "/repos/{owner}/{repo}/code-scanning/codeql/variant-analyses/{codeql_variant_analysis_id}", "/repos/{owner}/{repo}/code-scanning/codeql/variant-analyses/{codeql_variant_analysis_id}/repos/{repo_owner}/{repo_name}", "/repos/{owner}/{repo}/code-scanning/default-setup", "/repos/{owner}/{repo}/code-scanning/sarifs", "/repos/{owner}/{repo}/code-scanning/sarifs/{sarif_id}", "/repos/{owner}/{repo}/code-security-configuration", "/repos/{owner}/{repo}/codeowners/errors", "/repos/{owner}/{repo}/codespaces", "/repos/{owner}/{repo}/codespaces/devcontainers", "/repos/{owner}/{repo}/codespaces/machines", "/repos/{owner}/{repo}/codespaces/new", "/repos/{owner}/{repo}/codespaces/permissions_check", "/repos/{owner}/{repo}/codespaces/secrets", "/repos/{owner}/{repo}/codespaces/secrets/public-key", "/repos/{owner}/{repo}/codespaces/secrets/{secret_name}", "/repos/{owner}/{repo}/collaborators", "/repos/{owner}/{repo}/collaborators/{username}", "/repos/{owner}/{repo}/collaborators/{username}/permission", "/repos/{owner}/{repo}/comments", "/repos/{owner}/{repo}/comments/{comment_id}", "/repos/{owner}/{repo}/comments/{comment_id}/reactions", "/repos/{owner}/{repo}/comments/{comment_id}/reactions/{reaction_id}", "/repos/{owner}/{repo}/commits", "/repos/{owner}/{repo}/commits/{commit_sha}/branches-where-head", 
"/repos/{owner}/{repo}/commits/{commit_sha}/comments", "/repos/{owner}/{repo}/commits/{commit_sha}/pulls", "/repos/{owner}/{repo}/commits/{ref}", "/repos/{owner}/{repo}/commits/{ref}/check-runs", "/repos/{owner}/{repo}/commits/{ref}/check-suites", "/repos/{owner}/{repo}/commits/{ref}/status", "/repos/{owner}/{repo}/commits/{ref}/statuses", "/repos/{owner}/{repo}/community/profile", "/repos/{owner}/{repo}/compare/{basehead}", "/repos/{owner}/{repo}/contents/{path}", "/repos/{owner}/{repo}/contributors", "/repos/{owner}/{repo}/dependabot/alerts", "/repos/{owner}/{repo}/dependabot/alerts/{alert_number}", "/repos/{owner}/{repo}/dependabot/secrets", "/repos/{owner}/{repo}/dependabot/secrets/public-key", "/repos/{owner}/{repo}/dependabot/secrets/{secret_name}", "/repos/{owner}/{repo}/dependency-graph/compare/{basehead}", "/repos/{owner}/{repo}/dependency-graph/sbom", "/repos/{owner}/{repo}/dependency-graph/snapshots", "/repos/{owner}/{repo}/deployments", "/repos/{owner}/{repo}/deployments/{deployment_id}", "/repos/{owner}/{repo}/deployments/{deployment_id}/statuses", "/repos/{owner}/{repo}/deployments/{deployment_id}/statuses/{status_id}", "/repos/{owner}/{repo}/dispatches", "/repos/{owner}/{repo}/environments", "/repos/{owner}/{repo}/environments/{environment_name}", "/repos/{owner}/{repo}/environments/{environment_name}/deployment-branch-policies", "/repos/{owner}/{repo}/environments/{environment_name}/deployment-branch-policies/{branch_policy_id}", "/repos/{owner}/{repo}/environments/{environment_name}/deployment_protection_rules", "/repos/{owner}/{repo}/environments/{environment_name}/deployment_protection_rules/apps", "/repos/{owner}/{repo}/environments/{environment_name}/deployment_protection_rules/{protection_rule_id}", "/repos/{owner}/{repo}/environments/{environment_name}/secrets", "/repos/{owner}/{repo}/environments/{environment_name}/secrets/public-key", "/repos/{owner}/{repo}/environments/{environment_name}/secrets/{secret_name}", 
"/repos/{owner}/{repo}/environments/{environment_name}/variables", "/repos/{owner}/{repo}/environments/{environment_name}/variables/{name}", "/repos/{owner}/{repo}/events", "/repos/{owner}/{repo}/forks", "/repos/{owner}/{repo}/git/blobs", "/repos/{owner}/{repo}/git/blobs/{file_sha}", "/repos/{owner}/{repo}/git/commits", "/repos/{owner}/{repo}/git/commits/{commit_sha}", "/repos/{owner}/{repo}/git/matching-refs/{ref}", "/repos/{owner}/{repo}/git/ref/{ref}", "/repos/{owner}/{repo}/git/refs", "/repos/{owner}/{repo}/git/refs/{ref}", "/repos/{owner}/{repo}/git/tags", "/repos/{owner}/{repo}/git/tags/{tag_sha}", "/repos/{owner}/{repo}/git/trees", "/repos/{owner}/{repo}/git/trees/{tree_sha}", "/repos/{owner}/{repo}/hooks", "/repos/{owner}/{repo}/hooks/{hook_id}", "/repos/{owner}/{repo}/hooks/{hook_id}/config", "/repos/{owner}/{repo}/hooks/{hook_id}/deliveries", "/repos/{owner}/{repo}/hooks/{hook_id}/deliveries/{delivery_id}", "/repos/{owner}/{repo}/hooks/{hook_id}/deliveries/{delivery_id}/attempts", "/repos/{owner}/{repo}/hooks/{hook_id}/pings", "/repos/{owner}/{repo}/hooks/{hook_id}/tests", "/repos/{owner}/{repo}/import", "/repos/{owner}/{repo}/import/authors", "/repos/{owner}/{repo}/import/authors/{author_id}", "/repos/{owner}/{repo}/import/large_files", "/repos/{owner}/{repo}/import/lfs", "/repos/{owner}/{repo}/installation", "/repos/{owner}/{repo}/interaction-limits", "/repos/{owner}/{repo}/invitations", "/repos/{owner}/{repo}/invitations/{invitation_id}", "/repos/{owner}/{repo}/issues", "/repos/{owner}/{repo}/issues/comments", "/repos/{owner}/{repo}/issues/comments/{comment_id}", "/repos/{owner}/{repo}/issues/comments/{comment_id}/reactions", "/repos/{owner}/{repo}/issues/comments/{comment_id}/reactions/{reaction_id}", "/repos/{owner}/{repo}/issues/events", "/repos/{owner}/{repo}/issues/events/{event_id}", "/repos/{owner}/{repo}/issues/{issue_number}", "/repos/{owner}/{repo}/issues/{issue_number}/assignees", 
"/repos/{owner}/{repo}/issues/{issue_number}/assignees/{assignee}", "/repos/{owner}/{repo}/issues/{issue_number}/comments", "/repos/{owner}/{repo}/issues/{issue_number}/events", "/repos/{owner}/{repo}/issues/{issue_number}/labels", "/repos/{owner}/{repo}/issues/{issue_number}/labels/{name}", "/repos/{owner}/{repo}/issues/{issue_number}/lock", "/repos/{owner}/{repo}/issues/{issue_number}/reactions", "/repos/{owner}/{repo}/issues/{issue_number}/reactions/{reaction_id}", "/repos/{owner}/{repo}/issues/{issue_number}/timeline", "/repos/{owner}/{repo}/keys", "/repos/{owner}/{repo}/keys/{key_id}", "/repos/{owner}/{repo}/labels", "/repos/{owner}/{repo}/labels/{name}", "/repos/{owner}/{repo}/languages", "/repos/{owner}/{repo}/license", "/repos/{owner}/{repo}/merge-upstream", "/repos/{owner}/{repo}/merges", "/repos/{owner}/{repo}/milestones", "/repos/{owner}/{repo}/milestones/{milestone_number}", "/repos/{owner}/{repo}/milestones/{milestone_number}/labels", "/repos/{owner}/{repo}/notifications", "/repos/{owner}/{repo}/pages", "/repos/{owner}/{repo}/pages/builds", "/repos/{owner}/{repo}/pages/builds/latest", "/repos/{owner}/{repo}/pages/builds/{build_id}", "/repos/{owner}/{repo}/pages/deployments", "/repos/{owner}/{repo}/pages/deployments/{pages_deployment_id}", "/repos/{owner}/{repo}/pages/deployments/{pages_deployment_id}/cancel", "/repos/{owner}/{repo}/pages/health", "/repos/{owner}/{repo}/private-vulnerability-reporting", "/repos/{owner}/{repo}/projects", "/repos/{owner}/{repo}/properties/values", "/repos/{owner}/{repo}/pulls", "/repos/{owner}/{repo}/pulls/comments", "/repos/{owner}/{repo}/pulls/comments/{comment_id}", "/repos/{owner}/{repo}/pulls/comments/{comment_id}/reactions", "/repos/{owner}/{repo}/pulls/comments/{comment_id}/reactions/{reaction_id}", "/repos/{owner}/{repo}/pulls/{pull_number}", "/repos/{owner}/{repo}/pulls/{pull_number}/codespaces", "/repos/{owner}/{repo}/pulls/{pull_number}/comments", 
"/repos/{owner}/{repo}/pulls/{pull_number}/comments/{comment_id}/replies", "/repos/{owner}/{repo}/pulls/{pull_number}/commits", "/repos/{owner}/{repo}/pulls/{pull_number}/files", "/repos/{owner}/{repo}/pulls/{pull_number}/merge", "/repos/{owner}/{repo}/pulls/{pull_number}/requested_reviewers", "/repos/{owner}/{repo}/pulls/{pull_number}/reviews", "/repos/{owner}/{repo}/pulls/{pull_number}/reviews/{review_id}", "/repos/{owner}/{repo}/pulls/{pull_number}/reviews/{review_id}/comments", "/repos/{owner}/{repo}/pulls/{pull_number}/reviews/{review_id}/dismissals", "/repos/{owner}/{repo}/pulls/{pull_number}/reviews/{review_id}/events", "/repos/{owner}/{repo}/pulls/{pull_number}/update-branch", "/repos/{owner}/{repo}/readme", "/repos/{owner}/{repo}/readme/{dir}", "/repos/{owner}/{repo}/releases", "/repos/{owner}/{repo}/releases/assets/{asset_id}", "/repos/{owner}/{repo}/releases/generate-notes", "/repos/{owner}/{repo}/releases/latest", "/repos/{owner}/{repo}/releases/tags/{tag}", "/repos/{owner}/{repo}/releases/{release_id}", "/repos/{owner}/{repo}/releases/{release_id}/assets", "/repos/{owner}/{repo}/releases/{release_id}/reactions", "/repos/{owner}/{repo}/releases/{release_id}/reactions/{reaction_id}", "/repos/{owner}/{repo}/rules/branches/{branch}", "/repos/{owner}/{repo}/rulesets", "/repos/{owner}/{repo}/rulesets/rule-suites", "/repos/{owner}/{repo}/rulesets/rule-suites/{rule_suite_id}", "/repos/{owner}/{repo}/rulesets/{ruleset_id}", "/repos/{owner}/{repo}/secret-scanning/alerts", "/repos/{owner}/{repo}/secret-scanning/alerts/{alert_number}", "/repos/{owner}/{repo}/secret-scanning/alerts/{alert_number}/locations", "/repos/{owner}/{repo}/secret-scanning/push-protection-bypasses", "/repos/{owner}/{repo}/security-advisories", "/repos/{owner}/{repo}/security-advisories/reports", "/repos/{owner}/{repo}/security-advisories/{ghsa_id}", "/repos/{owner}/{repo}/security-advisories/{ghsa_id}/cve", "/repos/{owner}/{repo}/security-advisories/{ghsa_id}/forks", 
"/repos/{owner}/{repo}/stargazers", "/repos/{owner}/{repo}/stats/code_frequency", "/repos/{owner}/{repo}/stats/commit_activity", "/repos/{owner}/{repo}/stats/contributors", "/repos/{owner}/{repo}/stats/participation", "/repos/{owner}/{repo}/stats/punch_card", "/repos/{owner}/{repo}/statuses/{sha}", "/repos/{owner}/{repo}/subscribers", "/repos/{owner}/{repo}/subscription", "/repos/{owner}/{repo}/tags", "/repos/{owner}/{repo}/tags/protection", "/repos/{owner}/{repo}/tags/protection/{tag_protection_id}", "/repos/{owner}/{repo}/tarball/{ref}", "/repos/{owner}/{repo}/teams", "/repos/{owner}/{repo}/topics", "/repos/{owner}/{repo}/traffic/clones", "/repos/{owner}/{repo}/traffic/popular/paths", "/repos/{owner}/{repo}/traffic/popular/referrers", "/repos/{owner}/{repo}/traffic/views", "/repos/{owner}/{repo}/transfer", "/repos/{owner}/{repo}/vulnerability-alerts", "/repos/{owner}/{repo}/zipball/{ref}", "/repos/{template_owner}/{template_repo}/generate", "/repositories", "/search/code", "/search/commits", "/search/issues", "/search/labels", "/search/repositories", "/search/topics", "/search/users", "/teams/{team_id}", "/teams/{team_id}/discussions", "/teams/{team_id}/discussions/{discussion_number}", "/teams/{team_id}/discussions/{discussion_number}/comments", "/teams/{team_id}/discussions/{discussion_number}/comments/{comment_number}", "/teams/{team_id}/discussions/{discussion_number}/comments/{comment_number}/reactions", "/teams/{team_id}/discussions/{discussion_number}/reactions", "/teams/{team_id}/invitations", "/teams/{team_id}/members", "/teams/{team_id}/members/{username}", "/teams/{team_id}/memberships/{username}", "/teams/{team_id}/projects", "/teams/{team_id}/projects/{project_id}", "/teams/{team_id}/repos", "/teams/{team_id}/repos/{owner}/{repo}", "/teams/{team_id}/teams", "/user", "/user/blocks", "/user/blocks/{username}", "/user/codespaces", "/user/codespaces/secrets", "/user/codespaces/secrets/public-key", "/user/codespaces/secrets/{secret_name}", 
"/user/codespaces/secrets/{secret_name}/repositories", "/user/codespaces/secrets/{secret_name}/repositories/{repository_id}", "/user/codespaces/{codespace_name}", "/user/codespaces/{codespace_name}/exports", "/user/codespaces/{codespace_name}/exports/{export_id}", "/user/codespaces/{codespace_name}/machines", "/user/codespaces/{codespace_name}/publish", "/user/codespaces/{codespace_name}/start", "/user/codespaces/{codespace_name}/stop", "/user/docker/conflicts", "/user/email/visibility", "/user/emails", "/user/followers", "/user/following", "/user/following/{username}", "/user/gpg_keys", "/user/gpg_keys/{gpg_key_id}", "/user/installations", "/user/installations/{installation_id}/repositories", "/user/installations/{installation_id}/repositories/{repository_id}", "/user/interaction-limits", "/user/issues", "/user/keys", "/user/keys/{key_id}", "/user/marketplace_purchases", "/user/marketplace_purchases/stubbed", "/user/memberships/orgs", "/user/memberships/orgs/{org}", "/user/migrations", "/user/migrations/{migration_id}", "/user/migrations/{migration_id}/archive", "/user/migrations/{migration_id}/repos/{repo_name}/lock", "/user/migrations/{migration_id}/repositories", "/user/orgs", "/user/packages", "/user/packages/{package_type}/{package_name}", "/user/packages/{package_type}/{package_name}/restore", "/user/packages/{package_type}/{package_name}/versions", "/user/packages/{package_type}/{package_name}/versions/{package_version_id}", "/user/packages/{package_type}/{package_name}/versions/{package_version_id}/restore", "/user/projects", "/user/public_emails", "/user/repos", "/user/repository_invitations", "/user/repository_invitations/{invitation_id}", "/user/social_accounts", "/user/ssh_signing_keys", "/user/ssh_signing_keys/{ssh_signing_key_id}", "/user/starred", "/user/starred/{owner}/{repo}", "/user/subscriptions", "/user/teams", "/user/{account_id}", "/users", "/users/{username}", "/users/{username}/attestations/{subject_digest}", 
"/users/{username}/docker/conflicts", "/users/{username}/events", "/users/{username}/events/orgs/{org}", "/users/{username}/events/public", "/users/{username}/followers", "/users/{username}/following", "/users/{username}/following/{target_user}", "/users/{username}/gists", "/users/{username}/gpg_keys", "/users/{username}/hovercard", "/users/{username}/installation", "/users/{username}/keys", "/users/{username}/orgs", "/users/{username}/packages", "/users/{username}/packages/{package_type}/{package_name}", "/users/{username}/packages/{package_type}/{package_name}/restore", "/users/{username}/packages/{package_type}/{package_name}/versions", "/users/{username}/packages/{package_type}/{package_name}/versions/{package_version_id}", "/users/{username}/packages/{package_type}/{package_name}/versions/{package_version_id}/restore", "/users/{username}/projects", "/users/{username}/received_events", "/users/{username}/received_events/public", "/users/{username}/repos", "/users/{username}/settings/billing/actions", "/users/{username}/settings/billing/packages", "/users/{username}/settings/billing/shared-storage", "/users/{username}/social_accounts", "/users/{username}/ssh_signing_keys", "/users/{username}/starred", "/users/{username}/subscriptions", "/versions", "/zen" ] ================================================ FILE: tests/isolated/check_for_client_response_leak.py ================================================ import asyncio import contextlib import gc import socket import sys from aiohttp import ClientError, ClientSession, web from aiohttp.test_utils import REUSE_ADDRESS gc.set_debug(gc.DEBUG_LEAK) async def main() -> None: app = web.Application() async def stream_handler(request: web.Request) -> web.Response: assert request.transport is not None request.transport.close() # Forcefully closing connection return web.Response() app.router.add_get("/stream", stream_handler) with socket.create_server(("127.0.0.1", 0), reuse_port=REUSE_ADDRESS) as sock: port = 
async def main() -> None:
    """Run one failing request against a local server and report leaked requests.

    Starts an application on an ephemeral 127.0.0.1 port, issues a single GET
    to ``/json``, forces a garbage collection and exits the process with
    status 1 if any object whose type is named ``Request`` ended up in
    ``gc.garbage`` (status 0 otherwise).

    The check depends on the module-level ``gc.set_debug(gc.DEBUG_LEAK)``
    call, which makes the collector keep the unreachable objects it finds in
    ``gc.garbage`` instead of freeing them.
    """
    app = web.Application()

    async def handler(request: web.Request) -> NoReturn:
        # Annotated NoReturn: the handler is never expected to finish
        # normally. Presumably ``request.json()`` raises because the client
        # sends no body -- TODO confirm; the assert only guards that path.
        await request.json()
        assert False

    app.router.add_route("GET", "/json", handler)
    # Bind to port 0 so the OS picks a free port, then read it back.
    with socket.create_server(("127.0.0.1", 0), reuse_port=REUSE_ADDRESS) as sock:
        port = sock.getsockname()[1]

        runner = web.AppRunner(app)
        await runner.setup()
        # Serve on the already-bound socket rather than opening a new one.
        site = web.SockSite(runner, sock)
        await site.start()

        async with ClientSession() as session:
            async with session.get(f"http://127.0.0.1:{port}/json") as resp:
                await resp.read()

        # Give time for the cancelled task to be collected
        await asyncio.sleep(0.5)
        gc.collect()
        # Match on the type name so no class needs to be imported for the check.
        request_present = any(type(obj).__name__ == "Request" for obj in gc.garbage)
        # NOTE(review): the session was already closed when the ``async with``
        # block above exited, so this call looks redundant -- confirm.
        await session.close()
        await runner.cleanup()
        # Non-zero exit status signals a leak to whoever launched the script.
        sys.exit(1 if request_present else 0)
1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
async def test_loop() -> None:
    """BaseProtocol must keep the loop it is given, not the global one.

    The policy's current loop is cleared before the protocol is built, so the
    assertion can only pass if the explicitly supplied loop was stored. The
    previous loop is restored afterwards so the cleared global state does not
    leak into fixture teardown or later code.
    """
    # Inside a coroutine the running loop is the one driving this test.
    loop = asyncio.get_running_loop()
    asyncio.set_event_loop(None)
    try:
        pr = BaseProtocol(loop)
        assert pr._loop is loop
    finally:
        # Undo the global mutation even when the assertion fails.
        asyncio.set_event_loop(loop)
async def test_resume_drain_waited() -> None:
    """A drain started while writing is paused finishes after resume_writing()."""
    event_loop = asyncio.get_event_loop()
    protocol = BaseProtocol(loop=event_loop)
    transport = mock.Mock()
    protocol.connection_made(transport)
    protocol.pause_writing()

    # Schedule the drain and yield once so it can install its waiter.
    drain_task = event_loop.create_task(protocol._drain_helper())
    await asyncio.sleep(0)
    assert protocol._drain_waiter is not None

    # Resuming must wake the pending drain and clear the waiter.
    protocol.resume_writing()
    await drain_task
    assert protocol._drain_waiter is None
async def test_lost_drain_waited_exception() -> None:
    """A pending drain fails with ConnectionError when the connection is lost.

    The exception passed to ``connection_lost`` must be chained as the cause
    of the error raised to the drain waiter.
    """
    event_loop = asyncio.get_event_loop()
    protocol = BaseProtocol(loop=event_loop)
    transport = mock.Mock()
    protocol.connection_made(transport)
    protocol.pause_writing()

    # Schedule the drain and yield once so it can install its waiter.
    drain_task = event_loop.create_task(protocol._drain_helper())
    await asyncio.sleep(0)
    assert protocol._drain_waiter is not None

    failure = RuntimeError()
    protocol.connection_lost(failure)

    with pytest.raises(ConnectionError, match=r"^Connection lost$") as raised:
        await drain_task
    assert raised.value.__cause__ is failure
    assert protocol._drain_waiter is None
def test_one_hundred_simple_get_requests(
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark one client issuing 100 plain GET requests."""
    request_total = 100

    async def empty_reply(request: web.Request) -> web.Response:
        # An empty response keeps handler work out of the measurement.
        return web.Response()

    application = web.Application()
    application.router.add_route("GET", "/", empty_reply)

    async def issue_requests() -> None:
        test_client = await aiohttp_client(application)
        for _ in range(request_total):
            await test_client.get("/")
        await test_client.close()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(issue_requests())
def test_one_hundred_simple_get_requests_multiple_methods_route(
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark 100 GET requests against a path registered for many methods."""
    request_total = 100
    # GET intentionally registered last to ensure time complexity
    # of the route lookup is benchmarked
    registered_methods = ("DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT", "GET")

    async def empty_reply(request: web.Request) -> web.Response:
        return web.Response()

    application = web.Application()
    for http_method in registered_methods:
        application.router.add_route(http_method, "/", empty_reply)

    async def issue_requests() -> None:
        test_client = await aiohttp_client(application)
        for _ in range(request_total):
            await test_client.get("/")
        await test_client.close()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(issue_requests())
def test_one_hundred_get_requests_with_512kib_chunked_payload(
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark 100 GET requests reading a 512KiB chunked body with read()."""
    request_total = 100
    body = b"a" * (2**19)

    async def chunked_reply(request: web.Request) -> web.Response:
        response = web.Response(body=body)
        # Chunked transfer encoding is the variable this benchmark measures.
        response.enable_chunked_encoding()
        return response

    application = web.Application()
    application.router.add_route("GET", "/", chunked_reply)

    async def issue_requests() -> None:
        test_client = await aiohttp_client(application)
        for _ in range(request_total):
            reply = await test_client.get("/")
            await reply.read()
        await test_client.close()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(issue_requests())
None: client = await aiohttp_client(app) for _ in range(message_count): resp = await client.get("/") async for _ in resp.content.iter_chunks(): pass await client.close() @benchmark def _run() -> None: loop.run_until_complete(run_client_benchmark()) @pytest.mark.usefixtures("parametrize_zlib_backend") def test_get_request_with_251308_compressed_chunked_payload( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient, benchmark: BenchmarkFixture, ) -> None: """Benchmark compressed GET requests with a payload of 251308.""" # This payload compresses to 251308 bytes payload = b"".join( [ bytes((*range(0, i), *range(i, 0, -1))) for _ in range(255) for i in range(255) ] ) async def handler(request: web.Request) -> web.Response: resp = web.Response(body=payload, zlib_executor_size=16384) resp.enable_compression() return resp app = web.Application() app.router.add_route("GET", "/", handler) async def run_client_benchmark() -> None: client = await aiohttp_client(app) resp = await client.get("/") await resp.read() await client.close() @benchmark def _run() -> None: loop.run_until_complete(run_client_benchmark()) def test_one_hundred_get_requests_with_1024_content_length_payload( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient, benchmark: BenchmarkFixture, ) -> None: """Benchmark 100 GET requests with a small payload of 1024 bytes.""" message_count = 100 payload = b"a" * 1024 headers = {hdrs.CONTENT_LENGTH: str(len(payload))} async def handler(request: web.Request) -> web.Response: return web.Response(body=payload, headers=headers) app = web.Application() app.router.add_route("GET", "/", handler) async def run_client_benchmark() -> None: client = await aiohttp_client(app) for _ in range(message_count): resp = await client.get("/") await resp.read() await client.close() @benchmark def _run() -> None: loop.run_until_complete(run_client_benchmark()) def test_one_hundred_get_requests_with_30000_content_length_payload( loop: asyncio.AbstractEventLoop, 
aiohttp_client: AiohttpClient, benchmark: BenchmarkFixture, ) -> None: """Benchmark 100 GET requests with a payload of 30000 bytes.""" message_count = 100 payload = b"a" * 30000 headers = {hdrs.CONTENT_LENGTH: str(len(payload))} async def handler(request: web.Request) -> web.Response: return web.Response(body=payload, headers=headers) app = web.Application() app.router.add_route("GET", "/", handler) async def run_client_benchmark() -> None: client = await aiohttp_client(app) for _ in range(message_count): resp = await client.get("/") await resp.read() await client.close() @benchmark def _run() -> None: loop.run_until_complete(run_client_benchmark()) def test_one_hundred_get_requests_with_512kib_content_length_payload( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient, benchmark: BenchmarkFixture, ) -> None: """Benchmark 100 GET requests with a payload of 512KiB.""" message_count = 100 payload = b"a" * (2**19) headers = {hdrs.CONTENT_LENGTH: str(len(payload))} async def handler(request: web.Request) -> web.Response: return web.Response(body=payload, headers=headers) app = web.Application() app.router.add_route("GET", "/", handler) async def run_client_benchmark() -> None: client = await aiohttp_client(app) for _ in range(message_count): resp = await client.get("/") await resp.read() await client.close() @benchmark def _run() -> None: loop.run_until_complete(run_client_benchmark()) def test_one_hundred_simple_post_requests( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient, benchmark: BenchmarkFixture, ) -> None: """Benchmark 100 simple POST requests.""" message_count = 100 async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_route("POST", "/", handler) async def run_client_benchmark() -> None: client = await aiohttp_client(app) for _ in range(message_count): await client.post("/", data=b"any") await client.close() @benchmark def _run() -> None: 
loop.run_until_complete(run_client_benchmark()) def test_one_hundred_json_post_requests( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient, benchmark: BenchmarkFixture, ) -> None: """Benchmark 100 JSON POST requests that check the content-type.""" message_count = 100 async def handler(request: web.Request) -> web.Response: _ = request.content_type _ = request.charset return web.Response() app = web.Application() app.router.add_route("POST", "/", handler) async def run_client_benchmark() -> None: client = await aiohttp_client(app) for _ in range(message_count): await client.post("/", json={"key": "value"}) await client.close() @benchmark def _run() -> None: loop.run_until_complete(run_client_benchmark()) def test_ten_streamed_responses_iter_any( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient, benchmark: BenchmarkFixture, ) -> None: """Benchmark 10 streamed responses using iter_any.""" message_count = 10 data = b"x" * 65536 # 64 KiB chunk size async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse() await resp.prepare(request) for _ in range(10): await resp.write(data) return resp app = web.Application() app.router.add_route("GET", "/", handler) async def run_client_benchmark() -> None: client = await aiohttp_client(app) for _ in range(message_count): resp = await client.get("/") async for _ in resp.content.iter_any(): pass await client.close() @benchmark def _run() -> None: loop.run_until_complete(run_client_benchmark()) def test_ten_streamed_responses_iter_chunked_4096( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient, benchmark: BenchmarkFixture, ) -> None: """Benchmark 10 streamed responses using iter_chunked 4096.""" message_count = 10 data = b"x" * 65536 # 64 KiB chunk size, 4096 iter_chunked async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse() await resp.prepare(request) for _ in range(10): await resp.write(data) return resp app = 
web.Application() app.router.add_route("GET", "/", handler) async def run_client_benchmark() -> None: client = await aiohttp_client(app) for _ in range(message_count): resp = await client.get("/") async for _ in resp.content.iter_chunked(4096): pass await client.close() @benchmark def _run() -> None: loop.run_until_complete(run_client_benchmark()) def test_ten_streamed_responses_iter_chunked_65536( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient, benchmark: BenchmarkFixture, ) -> None: """Benchmark 10 streamed responses using iter_chunked 65536.""" message_count = 10 data = b"x" * 65536 # 64 KiB chunk size, 64 KiB iter_chunked async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse() await resp.prepare(request) for _ in range(10): await resp.write(data) return resp app = web.Application() app.router.add_route("GET", "/", handler) async def run_client_benchmark() -> None: client = await aiohttp_client(app) for _ in range(message_count): resp = await client.get("/") async for _ in resp.content.iter_chunked(65536): pass await client.close() @benchmark def _run() -> None: loop.run_until_complete(run_client_benchmark()) def test_ten_streamed_responses_iter_chunks( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient, benchmark: BenchmarkFixture, ) -> None: """Benchmark 10 streamed responses using iter_chunks.""" message_count = 10 data = b"x" * 65536 # 64 KiB chunk size async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse() await resp.prepare(request) for _ in range(10): await resp.write(data) return resp app = web.Application() app.router.add_route("GET", "/", handler) async def run_client_benchmark() -> None: client = await aiohttp_client(app) for _ in range(message_count): resp = await client.get("/") async for _ in resp.content.iter_chunks(): pass await client.close() @benchmark def _run() -> None: loop.run_until_complete(run_client_benchmark()) 
================================================
FILE: tests/test_benchmarks_client_request.py
================================================
"""codspeed benchmarks for client requests."""

import asyncio
import sys
from collections.abc import Callable
from http.cookies import BaseCookie
from typing import Any

from multidict import CIMultiDict
from pytest_codspeed import BenchmarkFixture
from yarl import URL

from aiohttp.client_reqrep import ClientRequest, ClientRequestArgs, ClientResponse
from aiohttp.cookiejar import CookieJar
from aiohttp.helpers import TimerNoop
from aiohttp.http_writer import HttpVersion11
from aiohttp.tracing import Trace

# ``Unpack`` is only imported from ``typing`` on Python 3.11+; older
# interpreters fall back to an untyped alias.
if sys.version_info >= (3, 11):
    from typing import Unpack

    _RequestMaker = Callable[[str, URL, Unpack[ClientRequestArgs]], ClientRequest]
else:
    _RequestMaker = Any


async def test_client_request_update_cookies(
    benchmark: BenchmarkFixture,
    make_client_request: _RequestMaker,
) -> None:
    """Benchmark ``ClientRequest._update_cookies`` with cookies filtered from a jar."""
    url = URL("http://python.org")
    req = make_client_request("get", url)
    cookie_jar = CookieJar()
    cookie_jar.update_cookies({"string": "Another string"})
    cookies = cookie_jar.filter_cookies(url)
    # Sanity check (outside the measured section) that the jar kept the cookie.
    assert cookies["string"].value == "Another string"

    @benchmark
    def _run() -> None:
        req._update_cookies(cookies=cookies)


def test_create_client_request_with_cookies(
    loop: asyncio.AbstractEventLoop,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark constructing a ``ClientRequest`` that is passed one cookie."""
    url = URL("http://python.org")
    cookie_jar = CookieJar()
    cookie_jar.update_cookies({"cookie": "value"})
    cookies = cookie_jar.filter_cookies(url)
    assert cookies["cookie"].value == "value"
    timer = TimerNoop()
    traces: list[Trace] = []
    headers = CIMultiDict[str]()

    @benchmark
    def _run() -> None:
        ClientRequest(
            method="get",
            url=url,
            loop=loop,
            params=None,
            skip_auto_headers=None,
            response_class=ClientResponse,
            proxy=None,
            proxy_auth=None,
            proxy_headers=None,
            timer=timer,
            session=None,  # type: ignore[arg-type]
            ssl=True,
            traces=traces,
            trust_env=False,
            server_hostname=None,
            headers=headers,
            data=None,
            cookies=cookies,
            auth=None,
            version=HttpVersion11,
            compress=False,
            chunked=None,
            expect100=False,
        )


def test_create_client_request_with_headers(
    loop: asyncio.AbstractEventLoop,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark constructing a ``ClientRequest`` that is passed two headers."""
    url = URL("http://python.org")
    timer = TimerNoop()
    traces: list[Trace] = []
    headers = CIMultiDict({"header": "value", "another": "header"})
    cookies = BaseCookie[str]()

    @benchmark
    def _run() -> None:
        ClientRequest(
            method="get",
            url=url,
            loop=loop,
            params=None,
            skip_auto_headers=None,
            response_class=ClientResponse,
            proxy=None,
            proxy_auth=None,
            proxy_headers=None,
            timer=timer,
            session=None,  # type: ignore[arg-type]
            ssl=True,
            traces=traces,
            trust_env=False,
            server_hostname=None,
            headers=headers,
            data=None,
            cookies=cookies,
            auth=None,
            version=HttpVersion11,
            compress=False,
            chunked=None,
            expect100=False,
        )


def test_send_client_request_one_hundred(
    loop: asyncio.AbstractEventLoop,
    benchmark: BenchmarkFixture,
    make_client_request: _RequestMaker,
) -> None:
    """Benchmark sending one ``ClientRequest`` 100 times over a mock connection."""
    url = URL("http://python.org")

    async def make_req() -> ClientRequest:
        """Need async context."""
        return make_client_request("get", url)

    req = loop.run_until_complete(make_req())

    class MockTransport(asyncio.Transport):
        """Mock transport for testing that does no real I/O."""

        def is_closing(self) -> bool:
            """Swallow is_closing."""
            return False

        def write(self, data: bytes | bytearray | memoryview) -> None:
            """Swallow writes."""

    class MockProtocol(asyncio.BaseProtocol):
        """Mock protocol that never pauses writing and swallows drain/timeout calls."""

        def __init__(self) -> None:
            self.transport = MockTransport()

        @property
        def writing_paused(self) -> bool:
            return False

        async def _drain_helper(self) -> None:
            """Swallow drain."""

        def start_timeout(self) -> None:
            """Swallow start_timeout."""

    class MockConnector:
        """Mock connector exposing only ``force_close``."""

        def __init__(self) -> None:
            self.force_close = False

    class MockConnection:
        """Mock connection wired to the mock protocol and mock connector."""

        def __init__(self) -> None:
            self.transport = None
            self.protocol = MockProtocol()
            self._connector = MockConnector()

    conn = MockConnection()

    async def send_requests() -> None:
        for _ in range(100):
            await req._send(conn)  # type: ignore[arg-type]
# NOTE: the first lines below finish test_send_client_request_one_hundred,
# whose body begins before this span.
    @benchmark
    def _run() -> None:
        loop.run_until_complete(send_requests())


================================================
FILE: tests/test_benchmarks_client_ws.py
================================================
"""codspeed benchmarks for websocket client."""

import asyncio

import pytest
from pytest_codspeed import BenchmarkFixture

from aiohttp import web
from aiohttp._websocket.helpers import MSG_SIZE
from aiohttp.pytest_plugin import AiohttpClient


def test_one_thousand_round_trip_websocket_text_messages(
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark round trip of 1000 WebSocket text messages."""
    message_count = 1000

    async def handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        for _ in range(message_count):
            await ws.send_str("answer")
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)

    async def run_websocket_benchmark() -> None:
        client = await aiohttp_client(app)
        resp = await client.ws_connect("/")
        for _ in range(message_count):
            await resp.receive()
        await resp.close()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(run_websocket_benchmark())


@pytest.mark.parametrize("msg_size", [6, MSG_SIZE * 4], ids=["small", "large"])
def test_one_thousand_round_trip_websocket_binary_messages(
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
    benchmark: BenchmarkFixture,
    msg_size: int,
) -> None:
    """Benchmark round trip of 1000 WebSocket binary messages."""
    message_count = 1000
    raw_message = b"x" * msg_size

    async def handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        for _ in range(message_count):
            await ws.send_bytes(raw_message)
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)

    async def run_websocket_benchmark() -> None:
        client = await aiohttp_client(app)
        resp = await client.ws_connect("/")
        for _ in range(message_count):
            await resp.receive()
        await resp.close()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(run_websocket_benchmark())


def test_one_thousand_large_round_trip_websocket_text_messages(
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark round trip of 100 large WebSocket text messages."""
    # NOTE: despite "one_thousand" in the test name, only 100 messages are
    # exchanged; each one is MSG_SIZE * 4 characters long.
    message_count = 100
    raw_message = "x" * MSG_SIZE * 4

    async def handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        for _ in range(message_count):
            await ws.send_str(raw_message)
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)

    async def run_websocket_benchmark() -> None:
        client = await aiohttp_client(app)
        resp = await client.ws_connect("/")
        for _ in range(message_count):
            await resp.receive()
        await resp.close()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(run_websocket_benchmark())


@pytest.mark.usefixtures("parametrize_zlib_backend")
def test_client_send_large_websocket_compressed_messages(
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark send of compressed WebSocket binary messages."""
    message_count = 10
    raw_message = b"x" * 2**19  # 512 KiB

    async def handler(request: web.Request) -> web.WebSocketResponse:
        # The server only consumes messages; the client is the sending side.
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        for _ in range(message_count):
            await ws.receive()
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)

    async def run_websocket_benchmark() -> None:
        client = await aiohttp_client(app)
        resp = await client.ws_connect("/", compress=15)
        for _ in range(message_count):
            await resp.send_bytes(raw_message)
        await resp.close()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(run_websocket_benchmark())


@pytest.mark.usefixtures("parametrize_zlib_backend")
def test_client_receive_large_websocket_compressed_messages(
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark receive of compressed WebSocket binary messages."""
    message_count = 10
    raw_message = b"x" * 2**19  # 512 KiB

    async def handler(request: web.Request) -> web.WebSocketResponse:
        # The server sends the messages; the client is the receiving side.
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        for _ in range(message_count):
            await ws.send_bytes(raw_message)
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)

    async def run_websocket_benchmark() -> None:
        client = await aiohttp_client(app)
        resp = await client.ws_connect("/", compress=15)
        for _ in range(message_count):
            await resp.receive()
        await resp.close()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(run_websocket_benchmark())


================================================
FILE: tests/test_benchmarks_cookiejar.py
================================================
"""codspeed benchmarks for cookies."""

from http.cookies import BaseCookie

from pytest_codspeed import BenchmarkFixture
from yarl import URL

from aiohttp.cookiejar import CookieJar


async def test_load_cookies_into_temp_cookiejar(benchmark: BenchmarkFixture) -> None:
    """Benchmark for creating a temp CookieJar and filtering by URL.

    This benchmark matches what the client request does when cookies
    are passed to the request.
    """
    all_cookies: BaseCookie[str] = BaseCookie()
    url = URL("http://example.com")
    cookies = {"cookie1": "value1", "cookie2": "value2"}

    @benchmark
    def _run() -> None:
        tmp_cookie_jar = CookieJar()
        tmp_cookie_jar.update_cookies(cookies)
        req_cookies = tmp_cookie_jar.filter_cookies(url)
        all_cookies.load(req_cookies)


================================================
FILE: tests/test_benchmarks_http_websocket.py
================================================
"""codspeed benchmarks for http websocket."""

import asyncio

import pytest
from pytest_codspeed import BenchmarkFixture

from aiohttp._websocket.helpers import MSG_SIZE, PACK_LEN3
from aiohttp._websocket.reader import WebSocketDataQueue
from aiohttp.base_protocol import BaseProtocol
from aiohttp.http_websocket import WebSocketReader, WebSocketWriter, WSMsgType


def test_read_large_binary_websocket_messages(
    loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture
) -> None:
    """Read one hundred large binary websocket messages."""
    queue = WebSocketDataQueue(BaseProtocol(loop), 2**18, loop=loop)
    reader = WebSocketReader(queue, max_msg_size=2**18)

    # PACK3 has a minimum message length of 2**16 bytes.
# NOTE: the first lines below finish test_read_large_binary_websocket_messages,
# whose body begins before this span.
    message = b"x" * ((2**16) + 1)
    msg_length = len(message)
    first_byte = 0x80 | 0 | WSMsgType.BINARY.value
    header = PACK_LEN3(first_byte, 127, msg_length)
    raw_message = header + message
    # Bind the method once so the measured loop avoids the attribute lookup.
    feed_data = reader.feed_data

    @benchmark
    def _run() -> None:
        for _ in range(100):
            feed_data(raw_message)


def test_read_one_hundred_websocket_text_messages(
    loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture
) -> None:
    """Benchmark reading 100 WebSocket text messages."""
    queue = WebSocketDataQueue(BaseProtocol(loop), 2**16, loop=loop)
    reader = WebSocketReader(queue, max_msg_size=2**16)
    raw_message = (
        b'\x81~\x01!{"id":1,"src":"shellyplugus-c049ef8c30e4","dst":"aios-1453812500'
        b'8","result":{"name":null,"id":"shellyplugus-c049ef8c30e4","mac":"C049EF8C30E'
        b'4","slot":1,"model":"SNPL-00116US","gen":2,"fw_id":"20231219-133953/1.1.0-g3'
        b'4b5d4f","ver":"1.1.0","app":"PlugUS","auth_en":false,"auth_domain":null}}'
    )
    feed_data = reader.feed_data

    @benchmark
    def _run() -> None:
        for _ in range(100):
            feed_data(raw_message)


class MockTransport(asyncio.Transport):
    """Mock transport for testing that does no real I/O."""

    def is_closing(self) -> bool:
        """Swallow is_closing."""
        return False

    def write(self, data: bytes | bytearray | memoryview) -> None:
        """Swallow writes."""


class MockProtocol(BaseProtocol):
    """Protocol whose drain helper returns immediately."""

    async def _drain_helper(self) -> None:
        """Swallow drain."""


def test_send_one_hundred_websocket_text_messages(
    loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture
) -> None:
    """Benchmark sending 100 WebSocket text messages."""
    writer = WebSocketWriter(MockProtocol(loop=loop), MockTransport())
    raw_message = b"Hello, World!" * 100

    async def _send_one_hundred_websocket_text_messages() -> None:
        for _ in range(100):
            await writer.send_frame(raw_message, WSMsgType.TEXT)

    @benchmark
    def _run() -> None:
        loop.run_until_complete(_send_one_hundred_websocket_text_messages())


def test_send_one_hundred_large_websocket_text_messages(
    loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture
) -> None:
    """Benchmark sending 100 large WebSocket text messages."""
    writer = WebSocketWriter(MockProtocol(loop=loop), MockTransport())
    raw_message = b"x" * MSG_SIZE * 4

    async def _send_one_hundred_websocket_text_messages() -> None:
        for _ in range(100):
            await writer.send_frame(raw_message, WSMsgType.TEXT)

    @benchmark
    def _run() -> None:
        loop.run_until_complete(_send_one_hundred_websocket_text_messages())


def test_send_one_hundred_websocket_text_messages_with_mask(
    loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture
) -> None:
    """Benchmark sending 100 masked WebSocket text messages."""
    writer = WebSocketWriter(MockProtocol(loop=loop), MockTransport(), use_mask=True)
    raw_message = b"Hello, World!" * 100

    async def _send_one_hundred_websocket_text_messages() -> None:
        for _ in range(100):
            await writer.send_frame(raw_message, WSMsgType.TEXT)

    @benchmark
    def _run() -> None:
        loop.run_until_complete(_send_one_hundred_websocket_text_messages())


@pytest.mark.usefixtures("parametrize_zlib_backend")
def test_send_one_hundred_websocket_compressed_messages(
    loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture
) -> None:
    """Benchmark sending 100 WebSocket compressed messages."""
    writer = WebSocketWriter(MockProtocol(loop=loop), MockTransport(), compress=15)
    raw_message = b"Hello, World!" * 100

    async def _send_one_hundred_websocket_compressed_messages() -> None:
        for _ in range(100):
            await writer.send_frame(raw_message, WSMsgType.BINARY)

    @benchmark
    def _run() -> None:
        loop.run_until_complete(_send_one_hundred_websocket_compressed_messages())


================================================
FILE: tests/test_benchmarks_http_writer.py
================================================
"""codspeed benchmarks for http writer."""

from multidict import CIMultiDict
from pytest_codspeed import BenchmarkFixture

from aiohttp import hdrs
from aiohttp.http_writer import _serialize_headers


def test_serialize_headers(benchmark: BenchmarkFixture) -> None:
    """Benchmark 100 calls to _serialize_headers."""
    status_line = "HTTP/1.1 200 OK"
    headers = CIMultiDict(
        {
            hdrs.CONTENT_TYPE: "text/plain",
            hdrs.CONTENT_LENGTH: "100",
            hdrs.CONNECTION: "keep-alive",
            hdrs.DATE: "Mon, 23 May 2005 22:38:34 GMT",
            hdrs.SERVER: "Test/1.0",
            hdrs.CONTENT_ENCODING: "gzip",
            hdrs.VARY: "Accept-Encoding",
            hdrs.CACHE_CONTROL: "no-cache",
            hdrs.PRAGMA: "no-cache",
            hdrs.EXPIRES: "0",
            hdrs.LAST_MODIFIED: "Mon, 23 May 2005 22:38:34 GMT",
            hdrs.ETAG: "1234567890",
        }
    )

    @benchmark
    def _run() -> None:
        for _ in range(100):
            _serialize_headers(status_line, headers)


================================================
FILE: tests/test_benchmarks_web_fileresponse.py
================================================
"""codspeed benchmarks for the web file responses."""

import asyncio
import pathlib

from multidict import CIMultiDict
from pytest_codspeed import BenchmarkFixture

from aiohttp import ClientResponse, web
from aiohttp.pytest_plugin import AiohttpClient


def test_simple_web_file_response(
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark creating 100 simple web.FileResponse."""
    response_count = 100
    filepath = pathlib.Path(__file__).parent / "sample.txt"

    async def handler(request: web.Request) -> web.FileResponse:
        return web.FileResponse(path=filepath)

    app = web.Application()
    app.router.add_route("GET", "/", handler)

    async def run_file_response_benchmark() -> None:
        client = await aiohttp_client(app)
        for _ in range(response_count):
            await client.get("/")
        await client.close()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(run_file_response_benchmark())


def test_simple_web_file_sendfile_fallback_response(
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark creating 100 simple web.FileResponse without sendfile."""
    response_count = 100
    filepath = pathlib.Path(__file__).parent / "sample.txt"

    async def handler(request: web.Request) -> web.FileResponse:
        transport = request.transport
        assert transport is not None
        # Mark the transport as not sendfile-compatible before responding.
        transport._sendfile_compatible = False  # type: ignore[attr-defined]
        return web.FileResponse(path=filepath)

    app = web.Application()
    app.router.add_route("GET", "/", handler)

    async def run_file_response_benchmark() -> None:
        client = await aiohttp_client(app)
        for _ in range(response_count):
            await client.get("/")
        await client.close()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(run_file_response_benchmark())


def test_simple_web_file_response_not_modified(
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark web.FileResponse that returns a 304."""
    response_count = 100
    filepath = pathlib.Path(__file__).parent / "sample.txt"

    async def handler(request: web.Request) -> web.FileResponse:
        return web.FileResponse(path=filepath)

    app = web.Application()
    app.router.add_route("GET", "/", handler)

    async def make_last_modified_header() -> CIMultiDict[str]:
        # Fetch the file once and echo its Last-Modified value back as
        # If-Modified-Since for the benchmarked requests.
        client = await aiohttp_client(app)
        resp = await client.get("/")
        last_modified = resp.headers["Last-Modified"]
        headers = CIMultiDict({"If-Modified-Since": last_modified})
        return headers

    async def run_file_response_benchmark(
        headers: CIMultiDict[str],
    ) -> ClientResponse:
        client = await aiohttp_client(app)
# NOTE: the first lines below finish test_simple_web_file_response_not_modified
# (inside its run_file_response_benchmark coroutine), which begins before this span.
        for _ in range(response_count):
            resp = await client.get("/", headers=headers)

        await client.close()
        # Only the last response is returned to the caller for the status check.
        return resp  # type: ignore[possibly-undefined]

    headers = loop.run_until_complete(make_last_modified_header())

    @benchmark
    def _run() -> None:
        resp = loop.run_until_complete(run_file_response_benchmark(headers))
        assert resp.status == 304


================================================
FILE: tests/test_benchmarks_web_middleware.py
================================================
"""codspeed benchmarks for web middlewares."""

import asyncio

from pytest_codspeed import BenchmarkFixture

from aiohttp import web
from aiohttp.pytest_plugin import AiohttpClient
from aiohttp.typedefs import Handler


def test_ten_web_middlewares(
    benchmark: BenchmarkFixture,
    loop: asyncio.AbstractEventLoop,
    aiohttp_client: AiohttpClient,
) -> None:
    """Benchmark 100 requests with 10 middlewares."""
    message_count = 100

    async def handler(request: web.Request) -> web.Response:
        return web.Response()

    app = web.Application()
    app.router.add_route("GET", "/", handler)

    class MiddlewareClass:
        """Pass-through middleware whose ``call`` method just awaits the handler."""

        async def call(
            self, request: web.Request, handler: Handler
        ) -> web.StreamResponse:
            return await handler(request)

    for _ in range(10):
        app.middlewares.append(MiddlewareClass().call)

    async def run_client_benchmark() -> None:
        client = await aiohttp_client(app)
        for _ in range(message_count):
            await client.get("/")
        await client.close()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(run_client_benchmark())


================================================
FILE: tests/test_benchmarks_web_response.py
================================================
"""codspeed benchmarks for the web responses."""

from pytest_codspeed import BenchmarkFixture

from aiohttp import web


def test_simple_web_response(benchmark: BenchmarkFixture) -> None:
    """Benchmark creating 100 simple web.Response."""
    response_count = 100

    @benchmark
    def _run() -> None:
        for _ in range(response_count):
            web.Response()


def test_web_response_with_headers(benchmark: BenchmarkFixture) -> None:
    """Benchmark creating 100 web.Response with headers."""
    response_count = 100
    headers = {
        "Content-Type": "text/plain",
        "Server": "aiohttp",
        "Date": "Sun, 01 Aug 2021 12:00:00 GMT",
    }

    @benchmark
    def _run() -> None:
        for _ in range(response_count):
            web.Response(headers=headers)


def test_web_response_with_bytes_body(
    benchmark: BenchmarkFixture,
) -> None:
    """Benchmark creating 100 web.Response with bytes."""
    response_count = 100

    @benchmark
    def _run() -> None:
        for _ in range(response_count):
            web.Response(body=b"Hello, World!")


def test_web_response_with_text_body(benchmark: BenchmarkFixture) -> None:
    """Benchmark creating 100 web.Response with text."""
    response_count = 100

    @benchmark
    def _run() -> None:
        for _ in range(response_count):
            web.Response(text="Hello, World!")


def test_simple_web_stream_response(benchmark: BenchmarkFixture) -> None:
    """Benchmark creating 100 simple web.StreamResponse."""
    response_count = 100

    @benchmark
    def _run() -> None:
        for _ in range(response_count):
            web.StreamResponse()


================================================
FILE: tests/test_benchmarks_web_urldispatcher.py
================================================
"""codspeed benchmarks for the URL dispatcher."""

import asyncio
import json
import pathlib
import random
import string
from pathlib import Path
from typing import NoReturn, cast
from unittest import mock

import pytest
from multidict import CIMultiDict, CIMultiDictProxy
from pytest_codspeed import BenchmarkFixture
from yarl import URL

import aiohttp
from aiohttp import web
from aiohttp.http import HttpVersion, RawRequestMessage


@pytest.fixture
def github_urls() -> list[str]:
    """GitHub api urls."""
    # The fixture provides OpenAPI generated info for github.
    # To update the local data file please run the following command:
    # $ curl https://raw.githubusercontent.com/github/rest-api-description/refs/heads/main/descriptions/api.github.com/api.github.com.json | jq ".paths | keys" > github-urls.json

    here = Path(__file__).parent
    with (here / "github-urls.json").open() as f:
        urls = json.load(f)

    return cast(list[str], urls)


def _mock_request(method: str, path: str) -> web.Request:
    """Build a ``web.Request`` for *method* and *path* whose collaborators are mocks."""
    message = RawRequestMessage(
        method,
        path,
        HttpVersion(1, 1),
        CIMultiDictProxy(CIMultiDict()),
        (),
        False,
        None,
        False,
        False,
        URL(path),
    )

    return web.Request(
        message, mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock()
    )


def test_resolve_root_route(
    loop: asyncio.AbstractEventLoop,
    benchmark: BenchmarkFixture,
) -> None:
    """Resolve top level PlainResources route 100 times."""
    resolve_count = 100

    async def handler(request: web.Request) -> NoReturn:
        assert False

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    app.freeze()
    router = app.router
    request = _mock_request(method="GET", path="/")

    async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None:
        ret = None
        for _ in range(resolve_count):
            ret = await router.resolve(request)

        return ret

    # Run once outside the benchmark to check the expected route is matched.
    ret = loop.run_until_complete(run_url_dispatcher_benchmark())
    assert ret is not None
    assert ret.get_info()["path"] == "/", ret.get_info()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(run_url_dispatcher_benchmark())


def test_resolve_root_route_with_many_fixed_routes(
    loop: asyncio.AbstractEventLoop,
    benchmark: BenchmarkFixture,
) -> None:
    """Resolve top level PlainResources route 100 times with many other fixed routes."""
    resolve_count = 100

    async def handler(request: web.Request) -> NoReturn:
        assert False

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    for count in range(250):
        app.router.add_route("GET", f"/api/server/dispatch/{count}/update", handler)
        app.router.add_route("GET", f"/api/server/dispatch/{count}", handler)
    app.router.add_route("GET", "/api/server/dispatch", handler)
    app.router.add_route("GET", "/api/server", handler)
    app.router.add_route("GET", "/api", handler)
    app.freeze()
    router = app.router
    request = _mock_request(method="GET", path="/")

    async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None:
        ret = None
        for _ in range(resolve_count):
            ret = await router.resolve(request)

        return ret

    ret = loop.run_until_complete(run_url_dispatcher_benchmark())
    assert ret is not None
    assert ret.get_info()["path"] == "/", ret.get_info()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(run_url_dispatcher_benchmark())


def test_resolve_static_root_route(
    loop: asyncio.AbstractEventLoop,
    benchmark: BenchmarkFixture,
) -> None:
    """Resolve top level StaticResource route 100 times."""
    resolve_count = 100

    app = web.Application()
    here = pathlib.Path(aiohttp.__file__).parent
    app.router.add_static("/", here)
    app.freeze()
    router = app.router
    request = _mock_request(method="GET", path="/")

    async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None:
        ret = None
        for _ in range(resolve_count):
            ret = await router.resolve(request)

        return ret

    ret = loop.run_until_complete(run_url_dispatcher_benchmark())
    assert ret is not None
    assert ret.get_info()["directory"] == here, ret.get_info()

    @benchmark
    def _run() -> None:
        loop.run_until_complete(run_url_dispatcher_benchmark())


def test_resolve_single_fixed_url_with_many_routes(
    loop: asyncio.AbstractEventLoop,
    benchmark: BenchmarkFixture,
) -> None:
    """Resolve PlainResources route 100 times."""
    resolve_count = 100

    async def handler(request: web.Request) -> NoReturn:
        assert False

    app = web.Application()
    for count in range(250):
        app.router.add_route("GET", f"/api/server/dispatch/{count}/update", handler)
    app.freeze()
    router = app.router
    request = _mock_request(method="GET", path="/api/server/dispatch/1/update")

    async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None:
        ret = None
        for _ in range(resolve_count):
            ret = await router.resolve(request)

        return ret
ret = loop.run_until_complete(run_url_dispatcher_benchmark()) assert ret is not None assert ret.get_info()["path"] == "/api/server/dispatch/1/update", ret.get_info() @benchmark def _run() -> None: loop.run_until_complete(run_url_dispatcher_benchmark()) def test_resolve_multiple_fixed_url_with_many_routes( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture, ) -> None: """Resolve 250 different PlainResources routes.""" async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() for count in range(250): app.router.add_route("GET", f"/api/server/dispatch/{count}/update", handler) app.freeze() router = app.router requests = [ _mock_request(method="GET", path=f"/api/server/dispatch/{count}/update") for count in range(250) ] async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) return ret ret = loop.run_until_complete(run_url_dispatcher_benchmark()) assert ret is not None assert ret.get_info()["path"] == "/api/server/dispatch/249/update", ret.get_info() @benchmark def _run() -> None: loop.run_until_complete(run_url_dispatcher_benchmark()) def test_resolve_multiple_level_fixed_url_with_many_routes( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture, ) -> None: """Resolve 1024 different PlainResources routes.""" async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() urls = [ f"/api/{a}/{b}/{c}/{d}/{e}/update" for a in ("a", "b", "c", "d") for b in ("e", "f", "g", "h") for c in ("i", "j", "k", "l") for d in ("m", "n", "o", "p") for e in ("n", "o", "p", "q") ] for url in urls: app.router.add_route("GET", url, handler) app.freeze() router = app.router requests = [(_mock_request(method="GET", path=url), url) for url in urls] async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request, path in requests: ret = await router.resolve(request) return ret ret = 
loop.run_until_complete(run_url_dispatcher_benchmark()) assert ret is not None assert ret.get_info()["path"] == url, ret.get_info() @benchmark def _run() -> None: loop.run_until_complete(run_url_dispatcher_benchmark()) def test_resolve_dynamic_resource_url_with_many_static_routes( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture, ) -> None: """Resolve different a DynamicResource when there are 250 PlainResources registered.""" async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() for count in range(250): app.router.add_route("GET", f"/api/server/other/{count}/update", handler) app.router.add_route("GET", "/api/server/dispatch/{customer}/update", handler) app.freeze() router = app.router requests = [ _mock_request(method="GET", path=f"/api/server/dispatch/{customer}/update") for customer in range(250) ] async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) return ret ret = loop.run_until_complete(run_url_dispatcher_benchmark()) assert ret is not None assert ( ret.get_info()["formatter"] == "/api/server/dispatch/{customer}/update" ), ret.get_info() @benchmark def _run() -> None: loop.run_until_complete(run_url_dispatcher_benchmark()) def test_resolve_dynamic_resource_url_with_many_dynamic_routes( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture, ) -> None: """Resolve different a DynamicResource when there are 250 DynamicResources registered.""" async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() for count in range(250): app.router.add_route( "GET", f"/api/server/other/{{customer}}/update{count}", handler ) app.router.add_route("GET", "/api/server/dispatch/{customer}/update", handler) app.freeze() router = app.router requests = [ _mock_request(method="GET", path=f"/api/server/dispatch/{customer}/update") for customer in range(250) ] async def run_url_dispatcher_benchmark() -> 
web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) return ret ret = loop.run_until_complete(run_url_dispatcher_benchmark()) assert ret is not None assert ( ret.get_info()["formatter"] == "/api/server/dispatch/{customer}/update" ), ret.get_info() @benchmark def _run() -> None: loop.run_until_complete(run_url_dispatcher_benchmark()) def test_resolve_dynamic_resource_url_with_many_dynamic_routes_with_common_prefix( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture, ) -> None: """Resolve different a DynamicResource when there are 250 DynamicResources registered with the same common prefix.""" async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() for count in range(250): app.router.add_route("GET", f"/api/{{customer}}/show_{count}", handler) app.router.add_route("GET", "/api/{customer}/update", handler) app.freeze() router = app.router requests = [ _mock_request(method="GET", path=f"/api/{customer}/update") for customer in range(250) ] async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) return ret ret = loop.run_until_complete(run_url_dispatcher_benchmark()) assert ret is not None assert ret.get_info()["formatter"] == "/api/{customer}/update", ret.get_info() @benchmark def _run() -> None: loop.run_until_complete(run_url_dispatcher_benchmark()) def test_resolve_gitapi( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture, github_urls: list[str], ) -> None: """Resolve DynamicResource for simulated github API.""" async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() for url in github_urls: app.router.add_get(url, handler) app.freeze() router = app.router # PR reviews API was selected absolutely voluntary. # It is not any special but sits somewhere in the middle of the urls list. # If anybody has better idea please suggest. 
alnums = string.ascii_letters + string.digits requests = [] for i in range(250): owner = "".join(random.sample(alnums, 10)) repo = "".join(random.sample(alnums, 10)) pull_number = random.randint(0, 250) requests.append( _mock_request( method="GET", path=f"/repos/{owner}/{repo}/pulls/{pull_number}/reviews" ) ) async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) return ret ret = loop.run_until_complete(run_url_dispatcher_benchmark()) assert ret is not None assert ( ret.get_info()["formatter"] == "/repos/{owner}/{repo}/pulls/{pull_number}/reviews" ), ret.get_info() @benchmark def _run() -> None: loop.run_until_complete(run_url_dispatcher_benchmark()) def test_resolve_gitapi_subapps( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture, github_urls: list[str], ) -> None: """Resolve DynamicResource for simulated github API, grouped in subapps.""" async def handler(request: web.Request) -> NoReturn: assert False subapps = { "gists": web.Application(), "orgs": web.Application(), "projects": web.Application(), "repos": web.Application(), "teams": web.Application(), "user": web.Application(), "users": web.Application(), } app = web.Application() for url in github_urls: parts = url.split("/") subapp = subapps.get(parts[1]) if subapp is not None: sub_url = "/".join([""] + parts[2:]) if not sub_url: sub_url = "/" subapp.router.add_get(sub_url, handler) else: app.router.add_get(url, handler) for key, subapp in subapps.items(): app.add_subapp("/" + key, subapp) app.freeze() router = app.router # PR reviews API was selected absolutely voluntary. # It is not any special but sits somewhere in the middle of the urls list. # If anybody has better idea please suggest. 
alnums = string.ascii_letters + string.digits requests = [] for i in range(250): owner = "".join(random.sample(alnums, 10)) repo = "".join(random.sample(alnums, 10)) pull_number = random.randint(0, 250) requests.append( _mock_request( method="GET", path=f"/repos/{owner}/{repo}/pulls/{pull_number}/reviews" ) ) async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) return ret ret = loop.run_until_complete(run_url_dispatcher_benchmark()) assert ret is not None assert ( ret.get_info()["formatter"] == "/repos/{owner}/{repo}/pulls/{pull_number}/reviews" ), ret.get_info() @benchmark def _run() -> None: loop.run_until_complete(run_url_dispatcher_benchmark()) def test_resolve_gitapi_root( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture, github_urls: list[str], ) -> None: """Resolve the plain root for simulated github API.""" async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() for url in github_urls: app.router.add_get(url, handler) app.freeze() router = app.router request = _mock_request(method="GET", path="/") async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for i in range(250): ret = await router.resolve(request) return ret ret = loop.run_until_complete(run_url_dispatcher_benchmark()) assert ret is not None assert ret.get_info()["path"] == "/", ret.get_info() @benchmark def _run() -> None: loop.run_until_complete(run_url_dispatcher_benchmark()) def test_resolve_prefix_resources_many_prefix_many_plain( loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture, ) -> None: """Resolve prefix resource (sub_app) whene 250 PlainResources registered and there are 250 subapps that shares the same sub_app path prefix.""" async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() for count in range(250): app.router.add_get(f"/api/server/other/{count}/update", handler) 
for count in range(250): subapp = web.Application() # sub_apps exists for handling deep enough nested route trees subapp.router.add_get("/deep/enough/sub/path", handler) app.add_subapp(f"/api/path/to/plugin/{count}", subapp) app.freeze() router = app.router requests = [ _mock_request(method="GET", path="/api/path/to/plugin/249/deep/enough/sub/path") for customer in range(250) ] async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) return ret ret = loop.run_until_complete(run_url_dispatcher_benchmark()) assert ret is not None assert ( ret.get_info()["path"] == "/api/path/to/plugin/249/deep/enough/sub/path" ), ret.get_info() @benchmark def _run() -> None: loop.run_until_complete(run_url_dispatcher_benchmark()) ================================================ FILE: tests/test_circular_imports.py ================================================ """Tests for circular imports in all local packages and modules. This ensures all internal packages can be imported right away without any need to import some other module before doing so. 
This module is based on an idea that pytest uses for self-testing: * https://github.com/sanitizers/octomachinery/blob/be18b54/tests/circular_imports_test.py * https://github.com/pytest-dev/pytest/blob/d18c75b/testing/test_meta.py * https://twitter.com/codewithanthony/status/1229445110510735361 """ import os import pkgutil import socket import subprocess import sys from collections.abc import Generator from itertools import chain from pathlib import Path from types import ModuleType from typing import TYPE_CHECKING, Union import pytest if TYPE_CHECKING: from _pytest.mark.structures import ParameterSet import aiohttp def _mark_aiohttp_worker_for_skipping( importables: list[str], ) -> list[Union[str, "ParameterSet"]]: return [ ( pytest.param( importable, marks=pytest.mark.skipif( not hasattr(socket, "AF_UNIX"), reason="It's a UNIX-only module" ), ) if importable == "aiohttp.worker" else importable ) for importable in importables ] def _find_all_importables(pkg: ModuleType) -> list[str]: """Find all importables in the project. Return them in order. """ return sorted( set( chain.from_iterable( _discover_path_importables(Path(p), pkg.__name__) for p in pkg.__path__ ), ), ) def _discover_path_importables( pkg_pth: Path, pkg_name: str, ) -> Generator[str, None, None]: """Yield all importables under a given path and package.""" for dir_path, _d, file_names in os.walk(pkg_pth): pkg_dir_path = Path(dir_path) if pkg_dir_path.parts[-1] == "__pycache__": continue if all(Path(_).suffix != ".py" for _ in file_names): continue rel_pt = pkg_dir_path.relative_to(pkg_pth) pkg_pref = ".".join((pkg_name,) + rel_pt.parts) yield from ( pkg_path for _, pkg_path, _ in pkgutil.walk_packages( (str(pkg_dir_path),), prefix=f"{pkg_pref}.", ) ) @pytest.mark.parametrize( "import_path", _mark_aiohttp_worker_for_skipping(_find_all_importables(aiohttp)), ) def test_no_warnings(import_path: str) -> None: """Verify that exploding importables doesn't explode. 
This is seeking for any import errors including ones caused by circular imports. """ imp_cmd = ( # fmt: off sys.executable, "-W", "error", # The following deprecation warning is triggered by importing # `gunicorn.util`. Hopefully, it'll get fixed in the future. See # https://github.com/benoitc/gunicorn/issues/2840 for detail. "-W", "ignore:module 'sre_constants' is " "deprecated:DeprecationWarning:pkg_resources._vendor.pyparsing", # Also caused by `gunicorn.util` importing `pkg_resources`: "-W", "ignore:Creating a LegacyVersion has been deprecated and " "will be removed in the next major release:" "DeprecationWarning:", # Deprecation warning emitted by setuptools v67.5.0+ triggered by importing # `gunicorn.util`. "-W", "ignore:pkg_resources is deprecated as an API:" "DeprecationWarning", "-c", f"import {import_path!s}", # fmt: on ) subprocess.check_call(imp_cmd) ================================================ FILE: tests/test_classbasedview.py ================================================ from unittest import mock import pytest from aiohttp import web from aiohttp.web_urldispatcher import View def test_ctor() -> None: request = mock.Mock() view = View(request) assert view.request is request async def test_render_ok() -> None: resp = web.Response(text="OK") class MyView(View): async def get(self) -> web.StreamResponse: return resp request = mock.Mock() request.method = "GET" resp2 = await MyView(request) assert resp is resp2 async def test_render_unknown_method() -> None: class MyView(View): async def get(self) -> web.StreamResponse: assert False options = get request = mock.Mock() request.method = "UNKNOWN" with pytest.raises(web.HTTPMethodNotAllowed) as ctx: await MyView(request) assert ctx.value.headers["allow"] == "GET,OPTIONS" assert ctx.value.status == 405 async def test_render_unsupported_method() -> None: class MyView(View): async def get(self) -> web.StreamResponse: assert False options = delete = get request = mock.Mock() request.method = "POST" with 
pytest.raises(web.HTTPMethodNotAllowed) as ctx: await MyView(request) assert ctx.value.headers["allow"] == "DELETE,GET,OPTIONS" assert ctx.value.status == 405 ================================================ FILE: tests/test_client_connection.py ================================================ import asyncio import gc from typing import Any from unittest import mock import pytest from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ConnectionKey from aiohttp.connector import BaseConnector, Connection @pytest.fixture def key() -> object: return object() @pytest.fixture def loop() -> Any: return mock.create_autospec(asyncio.AbstractEventLoop, spec_set=True, instance=True) @pytest.fixture def connector() -> mock.Mock: return mock.Mock() @pytest.fixture def protocol() -> mock.Mock: return mock.Mock(should_close=False) def test_ctor( connector: BaseConnector, key: ConnectionKey, protocol: ResponseHandler, loop: asyncio.AbstractEventLoop, ) -> None: conn = Connection(connector, key, protocol, loop) assert conn.protocol is protocol conn.close() def test_callbacks_on_close( connector: BaseConnector, key: ConnectionKey, protocol: ResponseHandler, loop: asyncio.AbstractEventLoop, ) -> None: conn = Connection(connector, key, protocol, loop) notified = False def cb() -> None: nonlocal notified notified = True conn.add_callback(cb) conn.close() assert notified def test_callbacks_on_release( connector: BaseConnector, key: ConnectionKey, protocol: ResponseHandler, loop: asyncio.AbstractEventLoop, ) -> None: conn = Connection(connector, key, protocol, loop) notified = False def cb() -> None: nonlocal notified notified = True conn.add_callback(cb) conn.release() assert notified def test_callbacks_exception( connector: BaseConnector, key: ConnectionKey, protocol: ResponseHandler, loop: asyncio.AbstractEventLoop, ) -> None: conn = Connection(connector, key, protocol, loop) notified = False def cb1() -> None: raise Exception def cb2() -> None: nonlocal 
notified notified = True conn.add_callback(cb1) conn.add_callback(cb2) conn.close() assert notified def test_del( connector: BaseConnector, key: ConnectionKey, protocol: ResponseHandler, loop: mock.Mock, ) -> None: loop.is_closed.return_value = False conn = Connection(connector, key, protocol, loop) exc_handler = mock.Mock() loop.set_exception_handler(exc_handler) with pytest.warns(ResourceWarning): del conn gc.collect() connector._release.assert_called_with(key, protocol, should_close=True) # type: ignore[attr-defined] msg = { "client_connection": mock.ANY, # conn was deleted "message": "Unclosed connection", } msg["source_traceback"] = mock.ANY loop.call_exception_handler.assert_called_with(msg) def test_close( connector: BaseConnector, key: ConnectionKey, protocol: ResponseHandler, loop: asyncio.AbstractEventLoop, ) -> None: conn = Connection(connector, key, protocol, loop) assert not conn.closed conn.close() assert conn._protocol is None connector._release.assert_called_with(key, protocol, should_close=True) # type: ignore[attr-defined] assert conn.closed def test_release( connector: BaseConnector, key: ConnectionKey, protocol: ResponseHandler, loop: asyncio.AbstractEventLoop, ) -> None: conn = Connection(connector, key, protocol, loop) assert not conn.closed conn.release() assert protocol.transport is not None assert not protocol.transport.close.called # type: ignore[attr-defined] assert conn._protocol is None connector._release.assert_called_with(key, protocol) # type: ignore[attr-defined] assert conn.closed def test_release_proto_should_close( connector: BaseConnector, key: ConnectionKey, protocol: ResponseHandler, loop: asyncio.AbstractEventLoop, ) -> None: protocol.should_close = True # type: ignore[misc] conn = Connection(connector, key, protocol, loop) assert not conn.closed conn.release() assert protocol.transport is not None assert not protocol.transport.close.called # type: ignore[attr-defined] assert conn._protocol is None 
connector._release.assert_called_with(key, protocol) # type: ignore[attr-defined] assert conn.closed def test_release_released( connector: BaseConnector, key: ConnectionKey, protocol: ResponseHandler, loop: asyncio.AbstractEventLoop, ) -> None: conn = Connection(connector, key, protocol, loop) conn.release() connector._release.reset_mock() # type: ignore[attr-defined] conn.release() assert protocol.transport is not None assert not protocol.transport.close.called # type: ignore[attr-defined] assert conn._protocol is None assert not connector._release.called # type: ignore[attr-defined] ================================================ FILE: tests/test_client_exceptions.py ================================================ import errno import pickle import pytest from multidict import CIMultiDict, CIMultiDictProxy from yarl import URL from aiohttp import client, client_reqrep class TestClientResponseError: request_info = client.RequestInfo( url=URL("http://example.com"), method="GET", headers=CIMultiDictProxy(CIMultiDict()), real_url=URL("http://example.com"), ) def test_default_status(self) -> None: err = client.ClientResponseError(history=(), request_info=self.request_info) assert err.status == 0 def test_status(self) -> None: err = client.ClientResponseError( status=400, history=(), request_info=self.request_info ) assert err.status == 400 @pytest.mark.xfail(reason="CIMultiDictProxy is not pickleable") def test_pickle(self) -> None: err = client.ClientResponseError(request_info=self.request_info, history=()) for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.request_info == self.request_info assert err2.history == () assert err2.status == 0 assert err2.message == "" assert err2.headers is None err = client.ClientResponseError( request_info=self.request_info, history=(), status=400, message="Something wrong", headers=CIMultiDict(foo="bar"), ) err.foo = "bar" # type: ignore[attr-defined] for 
proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.request_info == self.request_info assert err2.history == () assert err2.status == 400 assert err2.message == "Something wrong" # Use headers.get() to verify static type is correct. assert err2.headers.get("foo") == "bar" assert err2.foo == "bar" def test_repr(self) -> None: err = client.ClientResponseError(request_info=self.request_info, history=()) assert repr(err) == (f"ClientResponseError({self.request_info!r}, ())") err = client.ClientResponseError( request_info=self.request_info, history=(), status=400, message="Something wrong", headers=CIMultiDict(), ) assert repr(err) == ( "ClientResponseError(%r, (), status=400, " "message='Something wrong', headers=)" % (self.request_info,) ) def test_str(self) -> None: err = client.ClientResponseError( request_info=self.request_info, history=(), status=400, message="Something wrong", headers=CIMultiDict(), ) assert str(err) == ("400, message='Something wrong', url='http://example.com'") class TestClientConnectorError: connection_key = client_reqrep.ConnectionKey( host="example.com", port=8080, is_ssl=False, ssl=True, proxy=None, proxy_auth=None, proxy_headers_hash=None, ) def test_ctor(self) -> None: err = client.ClientConnectorError( connection_key=self.connection_key, os_error=OSError(errno.ENOENT, "No such file"), ) assert err.errno == errno.ENOENT assert err.strerror == "No such file" assert err.os_error.errno == errno.ENOENT assert err.os_error.strerror == "No such file" assert err.host == "example.com" assert err.port == 8080 assert err.ssl is True def test_pickle(self) -> None: err = client.ClientConnectorError( connection_key=self.connection_key, os_error=OSError(errno.ENOENT, "No such file"), ) err.foo = "bar" # type: ignore[attr-defined] for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.errno == errno.ENOENT assert 
err2.strerror == "No such file" assert err2.os_error.errno == errno.ENOENT assert err2.os_error.strerror == "No such file" assert err2.host == "example.com" assert err2.port == 8080 assert err2.ssl is True assert err2.foo == "bar" def test_repr(self) -> None: os_error = OSError(errno.ENOENT, "No such file") err = client.ClientConnectorError( connection_key=self.connection_key, os_error=os_error ) assert repr(err) == ( f"ClientConnectorError({self.connection_key!r}, {os_error!r})" ) def test_str(self) -> None: err = client.ClientConnectorError( connection_key=self.connection_key, os_error=OSError(errno.ENOENT, "No such file"), ) assert str(err) == ( "Cannot connect to host example.com:8080 ssl:default [No such file]" ) class TestClientConnectorCertificateError: connection_key = client_reqrep.ConnectionKey( host="example.com", port=8080, is_ssl=False, ssl=True, proxy=None, proxy_auth=None, proxy_headers_hash=None, ) def test_ctor(self) -> None: certificate_error = Exception("Bad certificate") err = client.ClientConnectorCertificateError( connection_key=self.connection_key, certificate_error=certificate_error ) assert err.certificate_error == certificate_error assert err.host == "example.com" assert err.port == 8080 assert err.ssl is False def test_pickle(self) -> None: certificate_error = Exception("Bad certificate") err = client.ClientConnectorCertificateError( connection_key=self.connection_key, certificate_error=certificate_error ) err.foo = "bar" for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.certificate_error.args == ("Bad certificate",) assert err2.host == "example.com" assert err2.port == 8080 assert err2.ssl is False assert err2.foo == "bar" def test_repr(self) -> None: certificate_error = Exception("Bad certificate") err = client.ClientConnectorCertificateError( connection_key=self.connection_key, certificate_error=certificate_error ) assert repr(err) == ( 
"ClientConnectorCertificateError(%r, %r)" % (self.connection_key, certificate_error) ) def test_str(self) -> None: certificate_error = Exception("Bad certificate") err = client.ClientConnectorCertificateError( connection_key=self.connection_key, certificate_error=certificate_error ) assert str(err) == ( "Cannot connect to host example.com:8080 ssl:False" " [Exception: ('Bad certificate',)]" ) def test_oserror(self) -> None: certificate_error = OSError(1, "Bad certificate") err = client.ClientConnectorCertificateError( connection_key=self.connection_key, certificate_error=certificate_error ) assert err.os_error == certificate_error assert err.errno == 1 assert err.strerror == "Bad certificate" class TestServerDisconnectedError: def test_ctor(self) -> None: err = client.ServerDisconnectedError() assert err.message == "Server disconnected" err = client.ServerDisconnectedError(message="No connection") assert err.message == "No connection" def test_pickle(self) -> None: err = client.ServerDisconnectedError(message="No connection") err.foo = "bar" # type: ignore[attr-defined] for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.message == "No connection" assert err2.foo == "bar" def test_repr(self) -> None: err = client.ServerDisconnectedError() assert repr(err) == ("ServerDisconnectedError('Server disconnected')") err = client.ServerDisconnectedError(message="No connection") assert repr(err) == "ServerDisconnectedError('No connection')" def test_str(self) -> None: err = client.ServerDisconnectedError() assert str(err) == "Server disconnected" err = client.ServerDisconnectedError(message="No connection") assert str(err) == "No connection" class TestServerFingerprintMismatch: def test_ctor(self) -> None: err = client.ServerFingerprintMismatch( expected=b"exp", got=b"got", host="example.com", port=8080 ) assert err.expected == b"exp" assert err.got == b"got" assert err.host == "example.com" assert 
err.port == 8080 def test_pickle(self) -> None: err = client.ServerFingerprintMismatch( expected=b"exp", got=b"got", host="example.com", port=8080 ) err.foo = "bar" # type: ignore[attr-defined] for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.expected == b"exp" assert err2.got == b"got" assert err2.host == "example.com" assert err2.port == 8080 assert err2.foo == "bar" def test_repr(self) -> None: err = client.ServerFingerprintMismatch(b"exp", b"got", "example.com", 8080) assert repr(err) == ( "" ) class TestInvalidURL: def test_ctor(self) -> None: err = client.InvalidURL(url=":wrong:url:", description=":description:") assert err.url == ":wrong:url:" assert err.description == ":description:" def test_pickle(self) -> None: err = client.InvalidURL(url=":wrong:url:") err.foo = "bar" # type: ignore[attr-defined] for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.url == ":wrong:url:" assert err2.foo == "bar" def test_repr_no_description(self) -> None: err = client.InvalidURL(url=":wrong:url:") assert err.args == (":wrong:url:",) assert repr(err) == "" def test_repr_yarl_URL(self) -> None: err = client.InvalidURL(url=URL(":wrong:url:")) assert repr(err) == "" def test_repr_with_description(self) -> None: err = client.InvalidURL(url=":wrong:url:", description=":description:") assert repr(err) == "" def test_str_no_description(self) -> None: err = client.InvalidURL(url=":wrong:url:") assert str(err) == ":wrong:url:" def test_none_description(self) -> None: err = client.InvalidURL(":wrong:url:") assert err.description is None def test_str_with_description(self) -> None: err = client.InvalidURL(url=":wrong:url:", description=":description:") assert str(err) == ":wrong:url: - :description:" ================================================ FILE: tests/test_client_fingerprint.py ================================================ 
def test_fingerprint_check_no_ssl() -> None:
    """``Fingerprint.check`` completes when the transport reports no extra info."""
    digest = hashlib.sha256(b"12345678" * 64).digest()
    fingerprint = aiohttp.Fingerprint(digest)

    # Every ``get_extra_info`` lookup on this transport yields ``None``.
    plain_transport = mock.Mock()
    plain_transport.get_extra_info.return_value = None

    fingerprint.check(plain_transport)
@pytest.fixture
def headers_echo_client(
    aiohttp_client: AiohttpClient,
) -> Callable[..., Awaitable[TestClient[web.Request, web.Application]]]:
    """Provide a factory for clients of a header-echoing application.

    Each call of the returned coroutine function builds a fresh application
    whose ``GET /`` route answers with the received request headers as JSON;
    keyword arguments are forwarded to ``aiohttp_client``.
    """

    async def echo_headers(request: web.Request) -> web.Response:
        received = dict(request.headers)
        return web.json_response({"headers": received})

    async def build_client(**kwargs: Any) -> TestClient[web.Request, web.Application]:
        echo_app = web.Application()
        echo_app.router.add_get("/", echo_headers)
        test_client = await aiohttp_client(echo_app, **kwargs)
        return test_client

    return build_client
async def test_keepalive_after_head_requests_success(
    aiohttp_client: AiohttpClient,
) -> None:
    """A GET issued after a HEAD reuses the single pooled connection once."""
    reuse_count = 0

    async def count_reuse(session: object, ctx: object, params: object) -> None:
        nonlocal reuse_count
        reuse_count += 1

    async def ok_handler(request: web.Request) -> web.Response:
        received = await request.read()
        assert b"" == received
        return web.Response(body=b"OK")

    tracing = aiohttp.TraceConfig()
    tracing._on_connection_reuseconn.append(count_reuse)

    application = web.Application()
    for verb in ("GET", "HEAD"):
        application.router.add_route(verb, "/", ok_handler)

    # limit=1 forces both requests through the same pool slot.
    client = await aiohttp_client(
        application,
        connector=aiohttp.TCPConnector(limit=1),
        trace_configs=[tracing],
    )

    async with client.head("/") as head_resp:
        await head_resp.read()
    async with client.get("/") as get_resp:
        await get_resp.read()

    assert reuse_count == 1
async def test_keepalive_response_released(aiohttp_client: AiohttpClient) -> None:
    """Releasing each response leaves exactly one pooled connection entry."""

    async def ok_handler(request: web.Request) -> web.Response:
        received = await request.read()
        assert b"" == received
        return web.Response(body=b"OK")

    application = web.Application()
    application.router.add_route("GET", "/", ok_handler)
    client = await aiohttp_client(
        application, connector=aiohttp.TCPConnector(limit=1)
    )

    # Two sequential requests, each released before the next one starts.
    for _ in range(2):
        response = await client.get("/")
        response.release()

    pool = client._session.connector
    assert pool is not None
    assert len(pool._conns) == 1
async def test_keepalive_timeout_async_sleep(unused_port_socket: socket.socket) -> None:
    """A second request succeeds after sleeping past the server keepalive timeout."""

    async def ok_handler(request: web.Request) -> web.Response:
        received = await request.read()
        assert b"" == received
        return web.Response(body=b"OK")

    application = web.Application()
    application.router.add_route("GET", "/", ok_handler)

    runner = web.AppRunner(application, tcp_keepalive=True, keepalive_timeout=0.001)
    await runner.setup()

    site = web.SockSite(runner, unused_port_socket)
    await site.start()

    host, port = unused_port_socket.getsockname()[:2]

    try:
        async with aiohttp.ClientSession() as session:
            first = await session.get(f"http://{host}:{port}/")
            await first.read()

            # wait for server keepalive_timeout
            await asyncio.sleep(0.01)

            second = await session.get(f"http://{host}:{port}/")
            await second.read()
    finally:
        await asyncio.gather(runner.shutdown(), site.stop())
async def test_HTTP_304(aiohttp_client: AiohttpClient) -> None:
    """A 304 response is received with an empty body."""

    async def not_modified(request: web.Request) -> web.Response:
        received = await request.read()
        assert b"" == received
        return web.Response(status=304)

    application = web.Application()
    application.router.add_route("GET", "/", not_modified)
    client = await aiohttp_client(application)

    async with client.get("/") as response:
        assert response.status == 304
        payload_bytes = await response.read()
        assert payload_bytes == b""
async def test_HTTP_304_WITH_BODY(aiohttp_client: AiohttpClient) -> None:
    """The client reads an empty body from a 304 even if the handler supplies one."""

    async def not_modified_with_body(request: web.Request) -> web.Response:
        received = await request.read()
        assert b"" == received
        return web.Response(body=b"test", status=304)

    application = web.Application()
    application.router.add_route("GET", "/", not_modified_with_body)
    client = await aiohttp_client(application)

    async with client.get("/") as response:
        assert response.status == 304
        payload_bytes = await response.read()
        assert payload_bytes == b""
async def test_post_data_bytesio(aiohttp_client: AiohttpClient) -> None:
    """Posting a ``BytesIO`` sends its content with a matching content length."""
    data = b"some buffer"

    async def verify_upload(request: web.Request) -> web.Response:
        assert len(data) == request.content_length
        received = await request.read()
        assert data == received
        return web.Response()

    application = web.Application()
    application.router.add_route("POST", "/", verify_upload)
    client = await aiohttp_client(application)

    with io.BytesIO(data) as buffer:
        async with client.post("/", data=buffer) as response:
            assert response.status == 200
async def test_post_data_zipfile_filelike(aiohttp_client: AiohttpClient) -> None:
    """Post a file-like object opened from a zip archive member.

    The handler asserts that the request body equals the bytes originally
    written into the archive member.
    """
    data = b"This is a zip file payload text file."

    async def handler(request: web.Request) -> web.Response:
        val = await request.read()
        assert data == val, "Transmitted zipfile member failed to match original data."
        return web.Response()

    app = web.Application()
    app.router.add_route("POST", "/", handler)
    client = await aiohttp_client(app)

    # Build a one-member archive in memory.
    buf = io.BytesIO()
    with zipfile.ZipFile(file=buf, mode="w") as zf:
        with zf.open("payload1.txt", mode="w") as zip_filelike_writing:
            zip_filelike_writing.write(data)

    buf.seek(0)
    # Open the archive and its member under ``with`` so both are closed
    # deterministically once the request has completed, instead of being
    # left for garbage collection.
    with (
        zipfile.ZipFile(file=buf, mode="r") as zf,
        zf.open("payload1.txt") as zip_filelike_reading,
    ):
        async with client.post("/", data=zip_filelike_reading) as resp:
            assert resp.status == 200
async def test_post_custom_payload_without_content_length(
    aiohttp_client: AiohttpClient,
) -> None:
    """Content-Length is derived from the payload size when none is given."""
    data = b"custom payload data"
    expected_length = len(data)

    async def verify_length(request: web.Request) -> web.Response:
        # The header must reflect the payload size in both parsed and raw form.
        assert request.content_length == expected_length
        assert request.headers.get("Content-Length") == str(expected_length)

        # The body itself must arrive intact.
        received = await request.read()
        assert data == received
        return web.Response()

    application = web.Application()
    application.router.add_route("POST", "/", verify_length)
    client = await aiohttp_client(application)

    # A BytesPayload is built directly and no Content-Length header is passed,
    # so the value has to come from the payload's own size.
    body = payload.BytesPayload(data)
    async with client.post("/", data=body) as response:
        assert response.status == 200
@pytest.mark.skipif(
    sys.version_info < (3, 11), reason="ssl_shutdown_timeout requires Python 3.11+"
)
async def test_ssl_client_shutdown_timeout(
    aiohttp_server: AiohttpServer,
    ssl_ctx: ssl.SSLContext,
    aiohttp_client: AiohttpClient,
    client_ssl_ctx: ssl.SSLContext,
) -> None:
    """Close a connector with ``ssl_shutdown_timeout=0.1`` during an active stream.

    The test checks three things: constructing the connector with
    ``ssl_shutdown_timeout`` emits a ``DeprecationWarning``, the value is
    stored on the connector, and ``connector.close()`` returns in under
    0.3 seconds while a background task is still reading the response.
    """
    # Test that ssl_shutdown_timeout is properly used during connection closure
    with pytest.warns(
        DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated"
    ):
        connector = aiohttp.TCPConnector(ssl=client_ssl_ctx, ssl_shutdown_timeout=0.1)

    async def streaming_handler(request: web.Request) -> NoReturn:
        # Create a streaming response that continuously sends data
        response = web.StreamResponse()
        await response.prepare(request)

        # Keep sending data until connection is closed
        while True:
            await response.write(b"data chunk\n")
            await asyncio.sleep(0.01)  # Small delay between chunks

        # Unreachable: the loop above has no exit, matching the NoReturn
        # annotation.
        assert False, "not reached"

    app = web.Application()
    app.router.add_route("GET", "/stream", streaming_handler)
    server = await aiohttp_server(app, ssl=ssl_ctx)
    client = await aiohttp_client(server, connector=connector)

    # Verify the connector has the correct timeout
    assert connector._ssl_shutdown_timeout == 0.1

    # Start a streaming request to establish SSL connection with active data transfer
    resp = await client.get("/stream")
    assert resp.status == 200

    # Create a background task that continuously reads data
    async def read_loop() -> None:
        while True:
            # Read "data chunk\n"
            await resp.content.read(11)

    read_task = asyncio.create_task(read_loop())
    await asyncio.sleep(0)  # Yield control to ensure read_task starts

    # Record the time before closing
    start_time = time.monotonic()

    # Now close the connector while the stream is still active
    # This will test the ssl_shutdown_timeout during an active connection
    await connector.close()

    # Verify the connection was closed within a reasonable time
    # Should be close to ssl_shutdown_timeout (0.1s) but allow some margin
    elapsed = time.monotonic() - start_time
    assert elapsed < 0.3, f"Connection closure took too long: {elapsed}s"

    # Stop the reader; only CancelledError is suppressed while awaiting it.
    read_task.cancel()
    with suppress(asyncio.CancelledError):
        await read_task
    assert read_task.done(), "Read task should be cancelled after connection closure"
async def test_tcp_connector_fingerprint_fail(
    aiohttp_server: AiohttpServer,
    aiohttp_client: AiohttpClient,
    ssl_ctx: ssl.SSLContext,
    tls_certificate_fingerprint_sha256: bytes,
) -> None:
    """A wrong fingerprint raises ``ServerFingerprintMismatch`` with both digests."""

    async def never_called(request: web.Request) -> NoReturn:
        assert False

    # All-zero digest of the same length as the real one.
    wrong_digest = bytes(len(tls_certificate_fingerprint_sha256))
    connector = aiohttp.TCPConnector(ssl=Fingerprint(wrong_digest))

    application = web.Application()
    application.router.add_route("GET", "/", never_called)
    server = await aiohttp_server(application, ssl=ssl_ctx)
    client = await aiohttp_client(server, connector=connector)

    with pytest.raises(ServerFingerprintMismatch) as raised:
        await client.get("/")

    mismatch = raised.value
    assert mismatch.expected == wrong_digest
    assert mismatch.got == tls_certificate_fingerprint_sha256
@pytest.mark.parametrize("params", [None, "", {}, MultiDict()]) async def test_empty_params_and_query_string( aiohttp_client: AiohttpClient, params: Query ) -> None: """Test combining empty params with an existing query_string.""" async def handler(request: web.Request) -> web.Response: assert request.rel_url.query_string == "q=abc" return web.Response() app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) async with client.get("/?q=abc", params=params) as resp: assert resp.status == 200 async def test_drop_params_on_redirect(aiohttp_client: AiohttpClient) -> None: async def handler_redirect(request: web.Request) -> web.Response: return web.Response(status=301, headers={"Location": "/ok?a=redirect"}) async def handler_ok(request: web.Request) -> web.Response: assert request.rel_url.query_string == "a=redirect" return web.Response(status=200) app = web.Application() app.router.add_route("GET", "/ok", handler_ok) app.router.add_route("GET", "/redirect", handler_redirect) client = await aiohttp_client(app) async with client.get("/redirect", params={"a": "initial"}) as resp: assert resp.status == 200 async def test_drop_fragment_on_redirect(aiohttp_client: AiohttpClient) -> None: async def handler_redirect(request: web.Request) -> web.Response: return web.Response(status=301, headers={"Location": "/ok#fragment"}) async def handler_ok(request: web.Request) -> web.Response: return web.Response(status=200) app = web.Application() app.router.add_route("GET", "/ok", handler_ok) app.router.add_route("GET", "/redirect", handler_redirect) client = await aiohttp_client(app) async with client.get("/redirect") as resp: assert resp.status == 200 assert resp.url.path == "/ok" async def test_drop_fragment(aiohttp_client: AiohttpClient) -> None: async def handler_ok(request: web.Request) -> web.Response: return web.Response(status=200) app = web.Application() app.router.add_route("GET", "/ok", handler_ok) client = await 
aiohttp_client(app) async with client.get("/ok#fragment") as resp: assert resp.status == 200 assert resp.url.path == "/ok" async def test_history(aiohttp_client: AiohttpClient) -> None: async def handler_redirect(request: web.Request) -> web.Response: return web.Response(status=301, headers={"Location": "/ok"}) async def handler_ok(request: web.Request) -> web.Response: return web.Response(status=200) app = web.Application() app.router.add_route("GET", "/ok", handler_ok) app.router.add_route("GET", "/redirect", handler_redirect) client = await aiohttp_client(app) async with client.get("/ok") as resp: assert len(resp.history) == 0 assert resp.status == 200 async with client.get("/redirect") as resp_redirect: assert len(resp_redirect.history) == 1 assert resp_redirect.history[0].status == 301 assert resp_redirect.status == 200 async def test_keepalive_closed_by_server(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: body = await request.read() assert b"" == body resp = web.Response(body=b"OK") resp.force_close() return resp app = web.Application() app.router.add_route("GET", "/", handler) connector = aiohttp.TCPConnector(limit=1) client = await aiohttp_client(app, connector=connector) async with client.get("/") as resp1: val1 = await resp1.read() assert val1 == b"OK" async with client.get("/") as resp2: val2 = await resp2.read() assert val2 == b"OK" assert client._session.connector is not None assert 0 == len(client._session.connector._conns) async def test_wait_for(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=b"OK") app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await asyncio.wait_for(client.get("/"), 10) assert resp.status == 200 txt = await resp.text() assert txt == "OK" async def test_raw_headers(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> 
web.Response: return web.Response() app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) async with client.get("/") as resp: assert resp.status == 200 raw_headers = tuple((bytes(h), bytes(v)) for h, v in resp.raw_headers) assert raw_headers == ( (b"Content-Length", b"0"), (b"Date", mock.ANY), (b"Server", mock.ANY), ) async def test_host_header_first(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: assert list(request.headers)[0] == hdrs.HOST return web.Response() app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) async with client.get("/") as resp: assert resp.status == 200 async def test_empty_header_values(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: resp = web.Response() resp.headers["X-Empty"] = "" return resp app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) async with client.get("/") as resp: assert resp.status == 200 raw_headers = tuple((bytes(h), bytes(v)) for h, v in resp.raw_headers) assert raw_headers == ( (b"X-Empty", b""), (b"Content-Length", b"0"), (b"Date", mock.ANY), (b"Server", mock.ANY), ) async def test_204_with_gzipped_content_encoding(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse(status=204) resp.content_length = 0 resp.content_type = "application/json" # resp.enable_compression(web.ContentCoding.gzip) resp.headers["Content-Encoding"] = "gzip" await resp.prepare(request) return resp app = web.Application() app.router.add_route("DELETE", "/", handler) client = await aiohttp_client(app) async with client.delete("/") as resp: assert resp.status == 204 assert resp.closed async def test_timeout_on_reading_headers( aiohttp_client: AiohttpClient, mocker: MockerFixture ) -> None: async def handler(request: web.Request) -> 
NoReturn: await asyncio.sleep(0.1) assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) with pytest.raises(asyncio.TimeoutError): await client.get("/", timeout=aiohttp.ClientTimeout(total=0.01)) async def test_timeout_on_conn_reading_headers( aiohttp_client: AiohttpClient, mocker: MockerFixture ) -> None: # tests case where user did not set a connection timeout async def handler(request: web.Request) -> NoReturn: await asyncio.sleep(0.1) assert False app = web.Application() app.router.add_route("GET", "/", handler) conn = aiohttp.TCPConnector() client = await aiohttp_client(app, connector=conn) with pytest.raises(asyncio.TimeoutError): await client.get("/", timeout=aiohttp.ClientTimeout(total=0.01)) async def test_timeout_on_session_read_timeout( aiohttp_client: AiohttpClient, mocker: MockerFixture ) -> None: async def handler(request: web.Request) -> NoReturn: await asyncio.sleep(0.1) assert False app = web.Application() app.router.add_route("GET", "/", handler) conn = aiohttp.TCPConnector() client = await aiohttp_client( app, connector=conn, timeout=aiohttp.ClientTimeout(sock_read=0.01) ) with pytest.raises(asyncio.TimeoutError): await client.get("/") async def test_read_timeout_between_chunks( aiohttp_client: AiohttpClient, mocker: MockerFixture ) -> None: async def handler(request: web.Request) -> web.StreamResponse: resp = aiohttp.web.StreamResponse() await resp.prepare(request) # write data 4 times, with pauses. Total time 2 seconds. for _ in range(4): await asyncio.sleep(0.5) await resp.write(b"data\n") return resp app = web.Application() app.add_routes([web.get("/", handler)]) # A timeout of 0.2 seconds should apply per read. 
timeout = aiohttp.ClientTimeout(sock_read=1) client = await aiohttp_client(app, timeout=timeout) res = b"" async with client.get("/") as resp: res += await resp.read() assert res == b"data\n" * 4 async def test_read_timeout_on_reading_chunks( aiohttp_client: AiohttpClient, mocker: MockerFixture ) -> None: async def handler(request: web.Request) -> NoReturn: resp = aiohttp.web.StreamResponse() await resp.prepare(request) await resp.write(b"data\n") await asyncio.sleep(1) assert False app = web.Application() app.add_routes([web.get("/", handler)]) # A timeout of 0.2 seconds should apply per read. timeout = aiohttp.ClientTimeout(sock_read=0.2) client = await aiohttp_client(app, timeout=timeout) async with client.get("/") as resp: assert (await resp.content.read(5)) == b"data\n" with pytest.raises(asyncio.TimeoutError): await resp.content.read() async def test_read_timeout_on_write(aiohttp_client: AiohttpClient) -> None: async def gen_payload() -> AsyncIterator[bytes]: # Delay writing to ensure read timeout isn't triggered before writing completes. await asyncio.sleep(0.5) yield b"foo" async def handler(request: web.Request) -> web.Response: return web.Response(body=await request.read()) app = web.Application() app.router.add_put("/", handler) timeout = aiohttp.ClientTimeout(total=None, sock_read=0.1) client = await aiohttp_client(app) async with client.put("/", data=gen_payload(), timeout=timeout) as resp: result = await resp.read() # Should not trigger a read timeout. 
assert result == b"foo" async def test_timeout_on_reading_data( aiohttp_client: AiohttpClient, mocker: MockerFixture ) -> None: loop = asyncio.get_event_loop() fut = loop.create_future() async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse(headers={"content-length": "100"}) await resp.prepare(request) fut.set_result(None) await asyncio.sleep(0.2) return resp app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) async with client.get("/", timeout=aiohttp.ClientTimeout(1)) as resp: await fut with pytest.raises(asyncio.TimeoutError): await resp.read() async def test_timeout_none( aiohttp_client: AiohttpClient, mocker: MockerFixture ) -> None: async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse() await resp.prepare(request) return resp app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) async with client.get("/", timeout=None) as resp: assert resp.status == 200 async def test_connection_timeout_error( aiohttp_client: AiohttpClient, mocker: MockerFixture ) -> None: """Test that ConnectionTimeoutError is raised when connection times out.""" async def handler(request: web.Request) -> NoReturn: assert False, "Handler should not be called" app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) # Mock the connector's connect method to raise asyncio.TimeoutError mock_connect = mocker.patch.object( client.session._connector, "connect", side_effect=asyncio.TimeoutError() ) with pytest.raises(aiohttp.ConnectionTimeoutError) as exc_info: await client.get("/", timeout=aiohttp.ClientTimeout(connect=0.01)) assert "Connection timeout to host" in str(exc_info.value) mock_connect.assert_called_once() async def test_readline_error_on_conn_close(aiohttp_client: AiohttpClient) -> None: loop = asyncio.get_event_loop() async def handler(request: web.Request) -> NoReturn: resp = 
web.StreamResponse() await resp.prepare(request) # make sure connection is closed by client. with pytest.raises(aiohttp.ServerDisconnectedError): for _ in range(10): await resp.write(b"data\n") await asyncio.sleep(0.5) assert False app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_client(app) async with aiohttp.ClientSession() as session: timer_started = False url, headers = server.make_url("/"), {"Connection": "Keep-alive"} resp = await session.get(url, headers=headers) with pytest.raises(aiohttp.ClientConnectionError): while True: data = await resp.content.readline() data = data.strip() assert data assert data == b"data" if not timer_started: loop.call_later(1.0, resp.release) timer_started = True async def test_no_error_on_conn_close_if_eof(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.StreamResponse: resp_ = web.StreamResponse() await resp_.prepare(request) await resp_.write(b"data\n") await asyncio.sleep(0.5) return resp_ app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_client(app) async with aiohttp.ClientSession() as session: url, headers = server.make_url("/"), {"Connection": "Keep-alive"} resp = await session.get(url, headers=headers) while True: data = await resp.content.readline() data = data.strip() if not data: break assert data == b"data" assert resp.content.exception() is None async def test_error_not_overwrote_on_conn_close(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.StreamResponse: resp_ = web.StreamResponse() await resp_.prepare(request) return resp_ app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_client(app) async with aiohttp.ClientSession() as session: url, headers = server.make_url("/"), {"Connection": "Keep-alive"} resp = await session.get(url, headers=headers) resp.content.set_exception(ValueError()) assert 
isinstance(resp.content.exception(), ValueError) async def test_HTTP_200_OK_METHOD(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(text=request.method) app = web.Application() for meth in ("get", "post", "put", "delete", "head", "patch", "options"): app.router.add_route(meth.upper(), "/", handler) client = await aiohttp_client(app) for meth in ("get", "post", "put", "delete", "head", "patch", "options"): async with client.request(meth, "/") as resp: assert resp.status == 200 assert len(resp.history) == 0 content1 = await resp.read() content2 = await resp.read() assert content1 == content2 content = await resp.text() if meth == "head": assert b"" == content1 else: assert meth.upper() == content async def test_HTTP_200_OK_METHOD_connector(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(text=request.method) conn = aiohttp.TCPConnector() conn.clear_dns_cache() app = web.Application() for meth in ("get", "post", "put", "delete", "head"): app.router.add_route(meth.upper(), "/", handler) client = await aiohttp_client(app, connector=conn) for meth in ("get", "post", "put", "delete", "head"): async with client.request(meth, "/") as resp: content1 = await resp.read() content2 = await resp.read() assert content1 == content2 content = await resp.text() assert resp.status == 200 if meth == "head": assert b"" == content1 else: assert meth.upper() == content async def test_HTTP_302_REDIRECT_GET(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(text=request.method) async def redirect(request: web.Request) -> NoReturn: raise web.HTTPFound(location="/") app = web.Application() app.router.add_get("/", handler) app.router.add_get("/redirect", redirect) client = await aiohttp_client(app) async with client.get("/redirect") as resp: assert 200 == resp.status assert 1 == 
len(resp.history) async def test_HTTP_302_REDIRECT_HEAD(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(text=request.method) async def redirect(request: web.Request) -> NoReturn: raise web.HTTPFound(location="/") app = web.Application() app.router.add_get("/", handler) app.router.add_get("/redirect", redirect) app.router.add_head("/", handler) app.router.add_head("/redirect", redirect) client = await aiohttp_client(app) async with client.request("head", "/redirect") as resp: assert 200 == resp.status assert 1 == len(resp.history) assert resp.method == "HEAD" async def test_HTTP_302_REDIRECT_NON_HTTP(aiohttp_client: AiohttpClient) -> None: async def redirect(request: web.Request) -> NoReturn: raise web.HTTPFound(location="ftp://127.0.0.1/test/") app = web.Application() app.router.add_get("/redirect", redirect) client = await aiohttp_client(app) with pytest.raises(NonHttpUrlRedirectClientError): await client.get("/redirect") async def test_HTTP_302_REDIRECT_POST(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(text=request.method) async def redirect(request: web.Request) -> NoReturn: raise web.HTTPFound(location="/") app = web.Application() app.router.add_get("/", handler) app.router.add_post("/redirect", redirect) client = await aiohttp_client(app) async with client.post("/redirect") as resp: assert resp.status == 200 assert 1 == len(resp.history) txt = await resp.text() assert txt == "GET" async def test_HTTP_302_REDIRECT_POST_with_content_length_hdr( aiohttp_client: AiohttpClient, ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(text=request.method) async def redirect(request: web.Request) -> NoReturn: await request.read() raise web.HTTPFound(location="/") data = json.dumps({"some": "data"}) app = web.Application() app.router.add_get("/", handler) app.router.add_post("/redirect", redirect) 
client = await aiohttp_client(app) async with client.post( "/redirect", data=data, headers={"Content-Length": str(len(data))} ) as resp: assert resp.status == 200 assert 1 == len(resp.history) txt = await resp.text() assert txt == "GET" async def test_HTTP_307_REDIRECT_POST(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(text=request.method) async def redirect(request: web.Request) -> NoReturn: await request.read() raise web.HTTPTemporaryRedirect(location="/") app = web.Application() app.router.add_post("/", handler) app.router.add_post("/redirect", redirect) client = await aiohttp_client(app) async with client.post("/redirect", data={"some": "data"}) as resp: assert resp.status == 200 assert 1 == len(resp.history) txt = await resp.text() assert txt == "POST" async def test_HTTP_308_PERMANENT_REDIRECT_POST(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(text=request.method) async def redirect(request: web.Request) -> NoReturn: await request.read() raise web.HTTPPermanentRedirect(location="/") app = web.Application() app.router.add_post("/", handler) app.router.add_post("/redirect", redirect) client = await aiohttp_client(app) async with client.post("/redirect", data={"some": "data"}) as resp: assert resp.status == 200 assert 1 == len(resp.history) txt = await resp.text() assert txt == "POST" async def test_HTTP_302_max_redirects(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: assert False async def redirect(request: web.Request) -> NoReturn: count = int(request.match_info["count"]) assert count raise web.HTTPFound(location=f"/redirect/{count - 1}") app = web.Application() app.router.add_get("/", handler) app.router.add_get(r"/redirect/{count:\d+}", redirect) client = await aiohttp_client(app) with pytest.raises(TooManyRedirects) as ctx: await client.get("/redirect/5", max_redirects=2) 
assert 2 == len(ctx.value.history) assert ctx.value.request_info.url.path == "/redirect/5" assert ctx.value.request_info.method == "GET" async def test_HTTP_200_GET_WITH_PARAMS(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response( text="&".join(k + "=" + v for k, v in request.query.items()) ) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) async with client.get("/", params={"q": "test"}) as resp: assert resp.status == 200 txt = await resp.text() assert txt == "q=test" async def test_HTTP_200_GET_WITH_MultiDict_PARAMS( aiohttp_client: AiohttpClient, ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response( text="&".join(k + "=" + v for k, v in request.query.items()) ) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) params = MultiDict([("q", "test"), ("q", "test2")]) async with client.get("/", params=params) as resp: assert resp.status == 200 txt = await resp.text() assert txt == "q=test&q=test2" async def test_HTTP_200_GET_WITH_MIXED_PARAMS(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response( text="&".join(k + "=" + v for k, v in request.query.items()) ) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) async with client.get("/?test=true", params={"q": "test"}) as resp: assert resp.status == 200 txt = await resp.text() assert txt == "test=true&q=test" async def test_POST_DATA(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.post() return web.json_response(dict(data)) app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) async with client.post("/", data={"some": "data"}) as resp: assert resp.status == 200 content = await resp.json() assert content == {"some": "data"} async def 
test_POST_DATA_with_explicit_formdata(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.post() return web.json_response(dict(data)) app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) form = aiohttp.FormData() form.add_field("name", "text") async with client.post("/", data=form) as resp: assert resp.status == 200 content = await resp.json() assert content == {"name": "text"} async def test_POST_DATA_with_charset(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: mp = await request.multipart() part = await mp.next() assert isinstance(part, aiohttp.BodyPartReader) text = await part.text() return web.Response(text=text) app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) form = aiohttp.FormData() form.add_field("name", "текст", content_type="text/plain; charset=koi8-r") async with client.post("/", data=form) as resp: assert resp.status == 200 content = await resp.text() assert content == "текст" async def test_POST_DATA_formdats_with_charset(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: mp = await request.post() assert "name" in mp assert isinstance(mp["name"], str) return web.Response(text=mp["name"]) app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) form = aiohttp.FormData(charset="koi8-r") form.add_field("name", "текст") async with client.post("/", data=form) as resp: assert resp.status == 200 content = await resp.text() assert content == "текст" async def test_POST_DATA_with_charset_post(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.post() assert isinstance(data["name"], str) return web.Response(text=data["name"]) app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) 
form = aiohttp.FormData() form.add_field("name", "текст", content_type="text/plain; charset=koi8-r") async with client.post("/", data=form) as resp: assert resp.status == 200 content = await resp.text() assert content == "текст" async def test_POST_MultiDict(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.post() assert data == MultiDict([("q", "test1"), ("q", "test2")]) return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) async with client.post( "/", data=MultiDict([("q", "test1"), ("q", "test2")]) ) as resp: assert 200 == resp.status async def test_GET_DEFLATE(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.json_response({"ok": True}) with mock.patch.object( ClientRequest, "_write_bytes", autospec=True, spec_set=True ) as m: app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) async with client.get("/", data=b"", compress=True) as resp: assert resp.status == 200 content = await resp.json() assert content == {"ok": True} # With an empty body, _write_bytes() should not be called at all. m.assert_not_called() async def test_GET_DEFLATE_no_body(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.json_response({"ok": True}) with mock.patch.object(ClientRequest, "_write_bytes") as mock_write_bytes: app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) async with client.get("/", data=None, compress=True) as resp: assert resp.status == 200 content = await resp.json() assert content == {"ok": True} # No chunks should have been sent for an empty body. 
mock_write_bytes.assert_not_called() async def test_POST_DATA_DEFLATE(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.post() return web.json_response(dict(data)) app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) # True is not a valid type, but still tested for backwards compatibility. async with client.post("/", data={"some": "data"}, compress=True) as resp: assert resp.status == 200 content = await resp.json() assert content == {"some": "data"} async def test_POST_FILES(aiohttp_client: AiohttpClient, fname: pathlib.Path) -> None: content1 = fname.read_bytes() async def handler(request: web.Request) -> web.Response: data = await request.post() assert isinstance(data["some"], web.FileField) assert data["some"].filename == fname.name content2 = await asyncio.to_thread(data["some"].file.read) assert content2 == content1 assert isinstance(data["test"], web.FileField) assert await asyncio.to_thread(data["test"].file.read) == b"data" assert isinstance(data["some"], web.FileField) data["some"].file.close() data["test"].file.close() return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with fname.open("rb") as f: async with client.post( "/", data={"some": f, "test": io.BytesIO(b"data")}, chunked=True ) as resp: assert 200 == resp.status async def test_POST_FILES_DEFLATE( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: content1 = fname.read_bytes() async def handler(request: web.Request) -> web.Response: data = await request.post() assert isinstance(data["some"], web.FileField) assert data["some"].filename == fname.name content2 = await asyncio.to_thread(data["some"].file.read) data["some"].file.close() assert content2 == content1 return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with fname.open("rb") as f: async with 
client.post( "/", data={"some": f}, chunked=True, compress="deflate" ) as resp: assert 200 == resp.status async def test_POST_bytes(aiohttp_client: AiohttpClient) -> None: body = b"0" * 12345 async def handler(request: web.Request) -> web.Response: data = await request.read() assert body == data return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) async with client.post("/", data=body) as resp: assert 200 == resp.status async def test_POST_bytes_too_large(aiohttp_client: AiohttpClient) -> None: body = b"0" * (2**20 + 1) async def handler(request: web.Request) -> web.Response: data = await request.content.read() assert body == data return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with pytest.warns(ResourceWarning): async with client.post("/", data=body) as resp: assert resp.status == 200 async def test_POST_FILES_STR( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: content1 = fname.read_bytes().decode() async def handler(request: web.Request) -> web.Response: data = await request.post() content2 = data["some"] assert content2 == content1 return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with fname.open("rb") as f: async with client.post("/", data={"some": f.read().decode()}) as resp: assert 200 == resp.status async def test_POST_FILES_STR_SIMPLE( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: content = fname.read_bytes() async def handler(request: web.Request) -> web.Response: data = await request.read() assert data == content return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with fname.open("rb") as f: async with client.post("/", data=f.read()) as resp: assert 200 == resp.status async def test_POST_FILES_LIST( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: content = 
fname.read_bytes() async def handler(request: web.Request) -> web.Response: data = await request.post() assert isinstance(data["some"], web.FileField) assert fname.name == data["some"].filename assert await asyncio.to_thread(data["some"].file.read) == content data["some"].file.close() return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with fname.open("rb") as f: async with client.post("/", data=[("some", f)]) as resp: assert 200 == resp.status async def test_POST_FILES_CT( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: content = fname.read_bytes() async def handler(request: web.Request) -> web.Response: data = await request.post() assert isinstance(data["some"], web.FileField) assert fname.name == data["some"].filename assert "text/plain" == data["some"].content_type assert await asyncio.to_thread(data["some"].file.read) == content data["some"].file.close() return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with fname.open("rb") as f: form = aiohttp.FormData() form.add_field("some", f, content_type="text/plain") async with client.post("/", data=form) as resp: assert 200 == resp.status async def test_POST_FILES_SINGLE( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: content = fname.read_bytes().decode() async def handler(request: web.Request) -> web.Response: data = await request.text() assert data == content # if system cannot determine 'text/x-python' MIME type # then use 'application/octet-stream' default assert request.content_type in [ "text/plain", "application/octet-stream", "text/x-python", ] assert "content-disposition" not in request.headers return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with fname.open("rb") as f: async with client.post("/", data=f) as resp: assert 200 == resp.status async def 
test_POST_FILES_SINGLE_content_disposition( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: content = fname.read_bytes().decode() async def handler(request: web.Request) -> web.Response: data = await request.text() assert data == content # if system cannot determine 'application/pgp-keys' MIME type # then use 'application/octet-stream' default assert request.content_type in [ "text/plain", "application/octet-stream", "text/x-python", ] assert request.headers["content-disposition"] == ( 'inline; filename="conftest.py"' ) return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with fname.open("rb") as f: async with client.post( "/", data=aiohttp.get_payload(f, disposition="inline") ) as resp: assert 200 == resp.status async def test_POST_FILES_SINGLE_BINARY( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: content = fname.read_bytes() async def handler(request: web.Request) -> web.Response: data = await request.read() assert data == content # if system cannot determine 'application/pgp-keys' MIME type # then use 'application/octet-stream' default assert request.content_type in [ "application/pgp-keys", "text/plain", "text/x-python", "application/octet-stream", ] return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with fname.open("rb") as f: async with client.post("/", data=f) as resp: assert 200 == resp.status async def test_POST_FILES_IO(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.post() assert isinstance(data["unknown"], web.FileField) assert b"data" == await asyncio.to_thread(data["unknown"].file.read) assert data["unknown"].content_type == "application/octet-stream" assert data["unknown"].filename == "unknown" data["unknown"].file.close() return web.Response() app = web.Application() app.router.add_post("/", handler) client = await 
async def test_POST_FILES_WITH_DATA(
    aiohttp_client: AiohttpClient, fname: pathlib.Path
) -> None:
    """Post a mapping that mixes a plain form value with an open file."""
    expected_bytes = fname.read_bytes()
    acceptable_types = (
        "text/x-python",
        "text/plain",
        "application/octet-stream",
    )

    async def check_form(request: web.Request) -> web.Response:
        form = await request.post()
        assert form["test"] == "true"
        assert isinstance(form["some"], web.FileField)
        assert form["some"].content_type in acceptable_types
        assert form["some"].filename == fname.name
        # File reads are blocking, so run them off the event loop thread.
        received = await asyncio.to_thread(form["some"].file.read)
        assert received == expected_bytes
        form["some"].file.close()
        return web.Response()

    app = web.Application()
    app.router.add_post("/", check_form)
    client = await aiohttp_client(app)

    with fname.open("rb") as upload:
        fields = {"test": "true", "some": upload}
        async with client.post("/", data=fields) as resp:
            assert resp.status == 200
async def test_json_custom(aiohttp_client: AiohttpClient) -> None:
    """Check that a custom ``json_serialize`` callable is used for ``json=``."""
    serializer_called = False

    def dumps(obj: Any) -> str:
        # Record the call, then defer to the standard encoder.
        nonlocal serializer_called
        serializer_called = True
        return json.dumps(obj)

    async def echo_json(request: web.Request) -> web.Response:
        assert request.content_type == "application/json"
        parsed = await request.json()
        return web.Response(body=aiohttp.JsonPayload(parsed))

    app = web.Application()
    app.router.add_post("/", echo_json)
    client = await aiohttp_client(app, json_serialize=dumps)

    async with client.post("/", json={"some": "data"}) as resp:
        assert resp.status == 200
        assert serializer_called
        echoed = await resp.json()
        assert echoed == {"some": "data"}

    # Passing both ``data`` and ``json`` must be rejected.
    with pytest.raises(ValueError):
        await client.post("/", data="some data", json={"some": "data"})
async def test_expect_continue(aiohttp_client: AiohttpClient) -> None:
    """Post with ``expect100=True`` and verify the expect handler ran."""
    expect_called = False

    async def on_expect(request: web.Request) -> None:
        nonlocal expect_called
        assert request.headers[hdrs.EXPECT].lower() == "100-continue"
        assert request.transport is not None
        # Write the interim response directly to the transport.
        request.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
        expect_called = True

    async def check_form(request: web.Request) -> web.Response:
        form = await request.post()
        assert form == {"some": "data"}
        return web.Response()

    app = web.Application()
    app.router.add_post("/", check_form, expect_handler=on_expect)
    client = await aiohttp_client(app)

    async with client.post("/", data={"some": "data"}, expect100=True) as resp:
        assert resp.status == 200

    assert expect_called
@pytest.mark.usefixtures("parametrize_zlib_backend")
async def test_encoding_deflate(
    aiohttp_client: AiohttpClient,
) -> None:
    """Fetch a chunked, deflate-compressed response and read it as text."""

    async def deflated_text(request: web.Request) -> web.Response:
        response = web.Response(text="text")
        response.enable_chunked_encoding()
        response.enable_compression(web.ContentCoding.deflate)
        return response

    app = web.Application()
    app.router.add_get("/", deflated_text)
    client = await aiohttp_client(app)

    async with client.get("/") as resp:
        assert resp.status == 200
        body = await resp.text()
        assert body == "text"
async def test_payload_decompress_size_limit(aiohttp_client: AiohttpClient) -> None:
    """Test that decompression size limit triggers DecompressSizeError.

    A deflate payload that expands past the configured limit must surface as
    ClientPayloadError whose cause is DecompressSizeError.
    """
    # 64 MiB of one repeated byte: highly compressible, and (as asserted
    # below) larger than the default decompression limit once expanded.
    expanded_size = 64 * 2**20
    original = b"A" * expanded_size
    compressed = zlib.compress(original)
    assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE

    async def serve_compressed(request: web.Request) -> web.Response:
        # Send the pre-compressed bytes and label them via Content-Encoding.
        response = web.Response(body=compressed)
        response.headers["Content-Encoding"] = "deflate"
        return response

    app = web.Application()
    app.router.add_get("/", serve_compressed)
    client = await aiohttp_client(app)

    async with client.get("/") as resp:
        assert resp.status == 200

        with pytest.raises(aiohttp.ClientPayloadError) as exc_info:
            await resp.read()

        cause = exc_info.value.__cause__
        assert isinstance(cause, DecompressSizeError)
        assert "Decompressed data exceeds" in str(cause)
@pytest.mark.skipif(ZstdCompressor is None, reason="backports.zstd is not installed")
async def test_payload_decompress_size_limit_zstd(
    aiohttp_client: AiohttpClient,
) -> None:
    """Test that zstd decompression size limit triggers DecompressSizeError."""
    assert ZstdCompressor is not None
    # 64 MiB of one repeated byte: highly compressible, and (as asserted
    # below) larger than the default decompression limit once expanded.
    expanded_size = 64 * 2**20
    original = b"A" * expanded_size
    encoder = ZstdCompressor()
    compressed = encoder.compress(original) + encoder.flush()
    assert len(original) > DEFAULT_MAX_DECOMPRESS_SIZE

    async def serve_compressed(request: web.Request) -> web.Response:
        response = web.Response(body=compressed)
        response.headers["Content-Encoding"] = "zstd"
        return response

    app = web.Application()
    app.router.add_get("/", serve_compressed)
    client = await aiohttp_client(app)

    async with client.get("/") as resp:
        assert resp.status == 200

        with pytest.raises(aiohttp.ClientPayloadError) as exc_info:
            await resp.read()

        cause = exc_info.value.__cause__
        assert isinstance(cause, DecompressSizeError)
        assert "Decompressed data exceeds" in str(cause)
async def test_no_payload_304_with_chunked_encoding(
    aiohttp_client: AiohttpClient,
) -> None:
    """Test a 304 response with no payload with chunked set should have it removed."""

    async def not_modified(request: web.Request) -> web.StreamResponse:
        response = web.StreamResponse(status=304)
        response.enable_chunked_encoding()
        # Disable the length check and set the header by hand so the
        # response is prepared with Transfer-Encoding explicitly requested.
        response._length_check = False
        response.headers["Transfer-Encoding"] = "chunked"
        payload_writer = await response.prepare(request)
        assert payload_writer is not None
        await payload_writer.write_eof()
        return response

    app = web.Application()
    app.router.add_get("/", not_modified)
    client = await aiohttp_client(app)

    async with client.get("/") as resp:
        assert resp.status == 304
        # Neither framing header may reach the client.
        for framing_header in (hdrs.CONTENT_LENGTH, hdrs.TRANSFER_ENCODING):
            assert framing_header not in resp.headers
        await resp.read()
async def test_chunked(aiohttp_client: AiohttpClient) -> None:
    """Fetch a response sent with chunked transfer encoding."""

    async def chunked_text(request: web.Request) -> web.Response:
        response = web.Response(text="text")
        response.enable_chunked_encoding()
        return response

    app = web.Application()
    app.router.add_get("/", chunked_text)
    client = await aiohttp_client(app)

    async with client.get("/") as resp:
        assert resp.status == 200
        assert resp.headers["Transfer-Encoding"] == "chunked"
        body = await resp.text()
        assert body == "text"
async def test_cookies(aiohttp_client: AiohttpClient) -> None:
    """Send session-level cookies given as a plain string and as a Morsel.

    The Morsel is stored under the mapping key "test2" but carries its own
    key "test3"; the handler expects to see "test1" and "test3".
    """
    morsel: http.cookies.Morsel[str] = http.cookies.Morsel()
    morsel.set("test3", "456", "456")

    async def check_cookies(request: web.Request) -> web.Response:
        received = request.cookies
        assert received.keys() == {"test1", "test3"}
        assert received["test1"] == "123"
        assert received["test3"] == "456"
        return web.Response()

    app = web.Application()
    app.router.add_get("/", check_cookies)
    session_cookies = {"test1": "123", "test2": morsel}
    client = await aiohttp_client(app, cookies=session_cookies)

    async with client.get("/") as resp:
        assert resp.status == 200
async def test_cookies_redirect(aiohttp_client: AiohttpClient) -> None:
    """Follow two redirects that each set cookie "c"; the last value wins."""

    async def first_hop(request: web.Request) -> web.Response:
        response = web.Response(status=301, headers={"Location": "/redirect2"})
        response.set_cookie("c", "1")
        return response

    async def second_hop(request: web.Request) -> web.Response:
        response = web.Response(status=301, headers={"Location": "/"})
        response.set_cookie("c", "2")
        return response

    async def final(request: web.Request) -> web.Response:
        assert request.cookies.keys() == {"c"}
        assert request.cookies["c"] == "2"
        return web.Response()

    app = web.Application()
    app.router.add_get("/redirect1", first_hop)
    app.router.add_get("/redirect2", second_hop)
    app.router.add_get("/", final)
    client = await aiohttp_client(app)

    async with client.get("/redirect1") as resp:
        assert resp.status == 200
async def test_set_cookies(
    aiohttp_client: AiohttpClient, caplog: pytest.LogCaptureFixture
) -> None:
    """Store valid Set-Cookie values and log a warning for an illegal name."""

    async def set_three_cookies(request: web.Request) -> web.Response:
        response = web.Response()
        response.set_cookie("c1", "cookie1")
        response.set_cookie("c2", "cookie2")
        response.headers.add(
            "Set-Cookie",
            "invalid,cookie=value; "  # Comma character is not allowed
            "HttpOnly; Path=/",
        )
        return response

    app = web.Application()
    app.router.add_get("/", set_three_cookies)
    client = await aiohttp_client(app)

    with caplog.at_level(logging.WARNING):
        async with client.get("/") as resp:
            assert resp.status == 200
            jar_keys = {cookie.key for cookie in client.session.cookie_jar}
            # Access the response cookies as well; the value itself is unused.
            _ = resp.cookies
            assert jar_keys == {"c1", "c2"}

    expected_warning = "Can not load cookies: Illegal cookie name 'invalid,cookie'"
    assert expected_warning in caplog.text
async def test_set_cookies_expired(aiohttp_client: AiohttpClient) -> None:
    """A cookie sent with an Expires date in 1980 must not be kept in the jar."""

    async def set_cookies(request: web.Request) -> web.Response:
        response = web.Response()
        response.set_cookie("c1", "cookie1")
        response.set_cookie("c2", "cookie2")
        # NOTE(review): there is no ';' between "Path=/" and "Expires=" in
        # this header value -- confirm whether that is intentional.
        response.headers.add(
            "Set-Cookie",
            "c3=cookie3; HttpOnly; Path=/ Expires=Tue, 1 Jan 1980 12:00:00 GMT; ",
        )
        return response

    app = web.Application()
    app.router.add_get("/", set_cookies)
    client = await aiohttp_client(app)

    async with client.get("/") as resp:
        assert resp.status == 200
        jar_keys = {cookie.key for cookie in client.session.cookie_jar}
        assert jar_keys == {"c1", "c2"}
async def test_broken_connection_2(aiohttp_client: AiohttpClient) -> None:
    """Reading a body cut short of its Content-Length raises ClientPayloadError."""

    async def truncated(request: web.Request) -> web.StreamResponse:
        # Promise 1000 bytes, send only six, then close the transport.
        response = web.StreamResponse(headers={"content-length": "1000"})
        await response.prepare(request)
        await response.write(b"answer")
        assert request.transport is not None
        request.transport.close()
        return response

    app = web.Application()
    app.router.add_get("/", truncated)
    client = await aiohttp_client(app)

    async with client.get("/") as resp:
        with pytest.raises(aiohttp.ClientPayloadError):
            await resp.read()
@pytest.mark.parametrize(
    ("url", "error_message_url", "expected_exception_class"),
    (
        *(
            (url, message, InvalidUrlClientError)
            for (url, message) in INVALID_URL_WITH_ERROR_MESSAGE_YARL_NEW
        ),
        *(
            (url, message, InvalidUrlClientError)
            for (url, message) in INVALID_URL_WITH_ERROR_MESSAGE_YARL_ORIGIN
        ),
        *(
            (url, message, NonHttpUrlClientError)
            for (url, message) in NON_HTTP_URL_WITH_ERROR_MESSAGE
        ),
    ),
)
async def test_invalid_and_non_http_url(
    url: str, error_message_url: str, expected_exception_class: type[Exception]
) -> None:
    """Requesting an invalid or non-HTTP URL raises the expected error.

    ``error_message_url`` is interpolated into a regular expression, so the
    parametrized values are already regex-escaped where needed.
    """
    # The message must start with the URL, optionally followed by " - <text>".
    message_pattern = rf"^{error_message_url}( - [A-Za-z ]+)?"

    async with aiohttp.ClientSession() as session:
        with pytest.raises(expected_exception_class, match=message_pattern):
            await session.get(url)
@pytest.mark.parametrize(
    ("invalid_redirect_url", "error_message_url", "expected_exception_class"),
    (
        *(
            (url, message, InvalidUrlRedirectClientError)
            for (url, message) in INVALID_URL_WITH_ERROR_MESSAGE_YARL_ORIGIN
            + INVALID_URL_WITH_ERROR_MESSAGE_YARL_NEW
        ),
        *(
            (url, message, NonHttpUrlRedirectClientError)
            for (url, message) in NON_HTTP_URL_WITH_ERROR_MESSAGE
        ),
    ),
)
async def test_invalid_redirect_url_multiple_redirects(
    aiohttp_client: AiohttpClient,
    invalid_redirect_url: str,
    error_message_url: str,
    expected_exception_class: type[Exception],
) -> None:
    """An invalid Location at the end of a redirect chain raises the error.

    The chain is /redirect -> /redirect1 -> /redirect2 -> invalid URL.

    Each route's target is bound through a factory. A handler defined
    directly in the loop body would read the loop variable only when called,
    after the loop has finished, so every route would redirect straight to
    ``invalid_redirect_url`` and the chain would never be followed.
    """
    hops = (
        ("/redirect", "/redirect1"),
        ("/redirect1", "/redirect2"),
        ("/redirect2", invalid_redirect_url),
    )
    visited: list[str] = []

    def make_redirect(location: str) -> Handler:
        # ``location`` is a parameter here, so each handler keeps its own value.
        async def generate_redirecting_response(request: web.Request) -> web.Response:
            visited.append(request.path)
            return web.Response(status=301, headers={hdrs.LOCATION: location})

        return generate_redirecting_response

    app = web.Application()
    for path, location in hops:
        app.router.add_get(path, make_redirect(location))
    client = await aiohttp_client(app)

    with pytest.raises(
        expected_exception_class, match=rf"^{error_message_url}( - [A-Za-z ]+)?"
    ):
        await client.get("/redirect")

    # Every hop must have been requested, in order, before the failure.
    assert visited == [path for path, _ in hops]
server.make_url("/")) as resp: assert resp.status == 400 async def test_request_raise_for_status_disabled(aiohttp_server: AiohttpServer) -> None: async def handler(request: web.Request) -> web.Response: raise web.HTTPBadRequest() app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app) url = server.make_url("/") async with aiohttp.request("GET", url, raise_for_status=False) as resp: assert resp.status == 400 async def test_request_raise_for_status_enabled(aiohttp_server: AiohttpServer) -> None: async def handler(request: web.Request) -> web.Response: raise web.HTTPBadRequest() app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app) url = server.make_url("/") with pytest.raises(aiohttp.ClientResponseError): async with aiohttp.request("GET", url, raise_for_status=True): assert False async def test_session_raise_for_status_coro(aiohttp_client: AiohttpClient) -> None: async def handle(request: web.Request) -> web.Response: return web.Response(text="ok") app = web.Application() app.router.add_route("GET", "/", handle) raise_for_status_called = 0 async def custom_r4s(response: aiohttp.ClientResponse) -> None: nonlocal raise_for_status_called raise_for_status_called += 1 assert response.status == 200 assert response.request_info.method == "GET" client = await aiohttp_client(app, raise_for_status=custom_r4s) await client.get("/") assert raise_for_status_called == 1 await client.get("/", raise_for_status=True) assert raise_for_status_called == 1 # custom_r4s not called again await client.get("/", raise_for_status=False) assert raise_for_status_called == 1 # custom_r4s not called again async def test_request_raise_for_status_coro(aiohttp_client: AiohttpClient) -> None: async def handle(request: web.Request) -> web.Response: return web.Response(text="ok") app = web.Application() app.router.add_route("GET", "/", handle) raise_for_status_called = 0 async def custom_r4s(response: aiohttp.ClientResponse) -> 
None: nonlocal raise_for_status_called raise_for_status_called += 1 assert response.status == 200 assert response.request_info.method == "GET" client = await aiohttp_client(app) await client.get("/", raise_for_status=custom_r4s) assert raise_for_status_called == 1 await client.get("/", raise_for_status=True) assert raise_for_status_called == 1 # custom_r4s not called again await client.get("/", raise_for_status=False) assert raise_for_status_called == 1 # custom_r4s not called again async def test_invalid_idna() -> None: async with aiohttp.ClientSession() as session: with pytest.raises(aiohttp.InvalidURL): await session.get("http://\u0080owhefopw.com") async def test_creds_in_auth_and_url() -> None: async with aiohttp.ClientSession() as session: with pytest.raises(ValueError): await session.get( "http://user:pass@example.com", auth=aiohttp.BasicAuth("user2", "pass2") ) async def test_creds_in_auth_and_redirect_url( create_server_for_url_and_handler: Callable[[URL, Handler], Awaitable[TestServer]], ) -> None: """Verify that credentials in redirect URLs can and do override any previous credentials.""" url_from = URL("http://example.com") url_to = URL("http://user@example.com") redirected = False async def srv(request: web.Request) -> web.Response: nonlocal redirected assert request.host == url_from.host if not redirected: redirected = True raise web.HTTPMovedPermanently(url_to) return web.Response() server = await create_server_for_url_and_handler(url_from, srv) etc_hosts = { (url_from.host, 80): server, } class FakeResolver(AbstractResolver): async def resolve( self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, ) -> list[ResolveResult]: server = etc_hosts[(host, port)] assert server.port is not None return [ { "hostname": host, "host": server.host, "port": server.port, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } ] async def close(self) -> None: """Dummy""" connector = 
aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False) async with ( aiohttp.ClientSession(connector=connector) as client, client.get(url_from, auth=aiohttp.BasicAuth("user", "pass")) as resp, ): assert len(resp.history) == 1 assert str(resp.url) == "http://example.com" assert resp.status == 200 assert ( resp.request_info.headers.get("authorization") == "Basic dXNlcjo=" ), "Expected redirect credentials to take precedence over provided auth" @pytest.fixture def create_server_for_url_and_handler( aiohttp_server: AiohttpServer, tls_certificate_authority: trustme.CA ) -> Callable[[URL, Handler], Awaitable[TestServer]]: def create(url: URL, srv: Handler) -> Awaitable[TestServer]: app = web.Application() app.router.add_route("GET", url.path, srv) if url.scheme == "https": assert url.host cert = tls_certificate_authority.issue_cert( url.host, "localhost", "127.0.0.1" ) ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) cert.configure_cert(ssl_ctx) return aiohttp_server(app, ssl=ssl_ctx) return aiohttp_server(app) return create @pytest.mark.parametrize( ["url_from_s", "url_to_s", "is_drop_header_expected"], [ [ "http://host1.com/path1", "http://host2.com/path2", True, ], ["http://host1.com/path1", "https://host1.com/path1", False], ["https://host1.com/path1", "http://host1.com/path2", True], ], ids=( "entirely different hosts", "http -> https", "https -> http", ), ) async def test_drop_auth_on_redirect_to_other_host( create_server_for_url_and_handler: Callable[[URL, Handler], Awaitable[TestServer]], url_from_s: str, url_to_s: str, is_drop_header_expected: bool, ) -> None: url_from, url_to = URL(url_from_s), URL(url_to_s) async def srv_from(request: web.Request) -> NoReturn: assert request.host == url_from.host assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz" raise web.HTTPFound(url_to) async def srv_to(request: web.Request) -> web.Response: assert request.host == url_to.host if is_drop_header_expected: assert "Authorization" not in request.headers, "Header 
wasn't dropped" assert "Proxy-Authorization" not in request.headers assert "Cookie" not in request.headers else: assert "Authorization" in request.headers, "Header was dropped" assert "Proxy-Authorization" in request.headers assert "Cookie" in request.headers return web.Response() server_from = await create_server_for_url_and_handler(url_from, srv_from) server_to = await create_server_for_url_and_handler(url_to, srv_to) assert ( url_from.host != url_to.host or server_from.scheme != server_to.scheme ), "Invalid test case, host or scheme must differ" protocol_port_map = { "http": 80, "https": 443, } etc_hosts = { (url_from.host, protocol_port_map[server_from.scheme]): server_from, (url_to.host, protocol_port_map[server_to.scheme]): server_to, } class FakeResolver(AbstractResolver): async def resolve( self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, ) -> list[ResolveResult]: server = etc_hosts[(host, port)] assert server.port is not None return [ { "hostname": host, "host": server.host, "port": server.port, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } ] async def close(self) -> None: """Dummy""" connector = aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False) async with aiohttp.ClientSession(connector=connector) as client: async with client.get( url_from, auth=aiohttp.BasicAuth("user", "pass"), headers={"Proxy-Authorization": "Basic dXNlcjpwYXNz", "Cookie": "a=b"}, ) as resp: assert resp.status == 200 async with client.get( url_from, headers={ "Authorization": "Basic dXNlcjpwYXNz", "Proxy-Authorization": "Basic dXNlcjpwYXNz", "Cookie": "a=b", }, ) as resp: assert resp.status == 200 async def test_auth_persist_on_redirect_to_other_host_with_global_auth( create_server_for_url_and_handler: Callable[[URL, Handler], Awaitable[TestServer]], ) -> None: url_from = URL("http://host1.com/path1") url_to = URL("http://host2.com/path2") async def srv_from(request: web.Request) -> NoReturn: assert request.host == 
url_from.host assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz" raise web.HTTPFound(url_to) async def srv_to(request: web.Request) -> web.Response: assert request.host == url_to.host assert "Authorization" in request.headers, "Header was dropped" return web.Response() server_from = await create_server_for_url_and_handler(url_from, srv_from) server_to = await create_server_for_url_and_handler(url_to, srv_to) assert ( url_from.host != url_to.host or server_from.scheme != server_to.scheme ), "Invalid test case, host or scheme must differ" protocol_port_map = { "http": 80, "https": 443, } etc_hosts = { (url_from.host, protocol_port_map[server_from.scheme]): server_from, (url_to.host, protocol_port_map[server_to.scheme]): server_to, } class FakeResolver(AbstractResolver): async def resolve( self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, ) -> list[ResolveResult]: server = etc_hosts[(host, port)] assert server.port is not None return [ { "hostname": host, "host": server.host, "port": server.port, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } ] async def close(self) -> None: """Dummy""" connector = aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False) async with aiohttp.ClientSession( connector=connector, auth=aiohttp.BasicAuth("user", "pass") ) as client: async with client.get(url_from) as resp: assert resp.status == 200 async def test_drop_auth_on_redirect_to_other_host_with_global_auth_and_base_url( create_server_for_url_and_handler: Callable[[URL, Handler], Awaitable[TestServer]], ) -> None: url_from = URL("http://host1.com/path1") url_to = URL("http://host2.com/path2") async def srv_from(request: web.Request) -> NoReturn: assert request.host == url_from.host assert request.headers["Authorization"] == "Basic dXNlcjpwYXNz" raise web.HTTPFound(url_to) async def srv_to(request: web.Request) -> web.Response: assert request.host == url_to.host assert "Authorization" not in request.headers, "Header 
was not dropped" return web.Response() server_from = await create_server_for_url_and_handler(url_from, srv_from) server_to = await create_server_for_url_and_handler(url_to, srv_to) assert ( url_from.host != url_to.host or server_from.scheme != server_to.scheme ), "Invalid test case, host or scheme must differ" protocol_port_map = { "http": 80, "https": 443, } etc_hosts = { (url_from.host, protocol_port_map[server_from.scheme]): server_from, (url_to.host, protocol_port_map[server_to.scheme]): server_to, } class FakeResolver(AbstractResolver): async def resolve( self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, ) -> list[ResolveResult]: server = etc_hosts[(host, port)] assert server.port is not None return [ { "hostname": host, "host": server.host, "port": server.port, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } ] async def close(self) -> None: """Dummy""" connector = aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False) async with aiohttp.ClientSession( connector=connector, base_url="http://host1.com", auth=aiohttp.BasicAuth("user", "pass"), ) as client: async with client.get("/path1") as resp: assert resp.status == 200 async def test_async_with_session() -> None: async with aiohttp.ClientSession() as session: pass assert session.closed async def test_session_close_awaitable() -> None: session = aiohttp.ClientSession() await session.close() assert session.closed async def test_close_resp_on_error_async_with_session( aiohttp_server: AiohttpServer, ) -> None: async def handler(request: web.Request) -> NoReturn: resp = web.StreamResponse(headers={"content-length": "100"}) await resp.prepare(request) await asyncio.sleep(0.1) assert False app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app) async with aiohttp.ClientSession() as session: with pytest.raises(RuntimeError): async with session.get(server.make_url("/")) as resp: resp.content.set_exception(RuntimeError()) await 
resp.read() assert session._connector is not None assert len(session._connector._conns) == 0 async def test_release_resp_on_normal_exit_from_cm( aiohttp_server: AiohttpServer, ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app) async with aiohttp.ClientSession() as session: async with session.get(server.make_url("/")) as resp: await resp.read() assert session._connector is not None assert len(session._connector._conns) == 1 async def test_non_close_detached_session_on_error_cm( aiohttp_server: AiohttpServer, ) -> None: async def handler(request: web.Request) -> NoReturn: resp = web.StreamResponse(headers={"content-length": "100"}) await resp.prepare(request) await asyncio.sleep(0.1) assert False app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app) session = aiohttp.ClientSession() cm = session.get(server.make_url("/")) assert not session.closed with pytest.raises(RuntimeError): async with cm as resp: resp.content.set_exception(RuntimeError()) await resp.read() assert not session.closed async def test_close_detached_session_on_non_existing_addr() -> None: class FakeResolver(AbstractResolver): async def resolve( self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, ) -> list[ResolveResult]: return [] async def close(self) -> None: """Dummy""" connector = aiohttp.TCPConnector(resolver=FakeResolver()) session = aiohttp.ClientSession(connector=connector) async with session: cm = session.get("http://non-existing.example.com") assert not session.closed with pytest.raises(Exception): await cm assert session.closed async def test_aiohttp_request_context_manager(aiohttp_server: AiohttpServer) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app) async with 
aiohttp.request("GET", server.make_url("/")) as resp: await resp.read() assert resp.status == 200 async def test_aiohttp_request_ctx_manager_close_sess_on_error( ssl_ctx: ssl.SSLContext, aiohttp_server: AiohttpServer ) -> None: async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app, ssl=ssl_ctx) cm = aiohttp.request("GET", server.make_url("/")) with pytest.raises(aiohttp.ClientConnectionError): async with cm: pass assert cm._session.closed # Allow event loop to process transport cleanup # on Python < 3.11 await asyncio.sleep(0) async def test_aiohttp_request_ctx_manager_not_found() -> None: with pytest.raises(aiohttp.ClientConnectionError): async with aiohttp.request("GET", "http://wrong-dns-name.com"): assert False async def test_raising_client_connector_dns_error_on_dns_failure() -> None: """Verify that the exception raised when a DNS lookup fails is specific to DNS.""" with mock.patch( "aiohttp.connector.TCPConnector._resolve_host", autospec=True, spec_set=True ) as mock_resolve_host: mock_resolve_host.side_effect = OSError(None, "DNS lookup failed") with pytest.raises(aiohttp.ClientConnectorDNSError, match="DNS lookup failed"): async with aiohttp.request("GET", "http://wrong-dns-name.com"): assert False, "never executed" async def test_aiohttp_request_coroutine(aiohttp_server: AiohttpServer) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app) not_an_awaitable = aiohttp.request("GET", server.make_url("/")) with pytest.raises( TypeError, match=( "^'_SessionRequestContextManager' object can't be awaited$" if sys.version_info >= (3, 14) else "^object _SessionRequestContextManager " "can't be used in 'await' expression$" ), ): await not_an_awaitable # type: ignore[misc] await not_an_awaitable._coro # coroutine 'ClientSession._request' was 
never awaited await server.close() async def test_aiohttp_request_ssl( aiohttp_server: AiohttpServer, ssl_ctx: ssl.SSLContext, client_ssl_ctx: ssl.SSLContext, ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app, ssl=ssl_ctx) async with aiohttp.request("GET", server.make_url("/"), ssl=client_ssl_ctx) as resp: assert resp.status == 200 async def test_yield_from_in_session_request(aiohttp_client: AiohttpClient) -> None: # a test for backward compatibility with yield from syntax async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) async with client.get("/") as resp: assert resp.status == 200 async def test_session_auth( headers_echo_client: Callable[ ..., Awaitable[TestClient[web.Request, web.Application]] ], ) -> None: client = await headers_echo_client(auth=aiohttp.BasicAuth("login", "pass")) async with client.get("/") as r: assert r.status == 200 content = await r.json() assert content["headers"]["Authorization"] == "Basic bG9naW46cGFzcw==" async def test_session_auth_override( headers_echo_client: Callable[ ..., Awaitable[TestClient[web.Request, web.Application]] ], ) -> None: client = await headers_echo_client(auth=aiohttp.BasicAuth("login", "pass")) async with client.get("/", auth=aiohttp.BasicAuth("other_login", "pass")) as r: assert r.status == 200 content = await r.json() val = content["headers"]["Authorization"] assert val == "Basic b3RoZXJfbG9naW46cGFzcw==" async def test_session_auth_header_conflict(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, auth=aiohttp.BasicAuth("login", "pass")) headers = {"Authorization": "Basic b3RoZXJfbG9naW46cGFzcw=="} with 
pytest.raises(ValueError): await client.get("/", headers=headers) @pytest.mark.usefixtures("netrc_default_contents") async def test_netrc_auth_from_env( # type: ignore[misc] headers_echo_client: Callable[ ..., Awaitable[TestClient[web.Request, web.Application]] ], ) -> None: """Test that netrc authentication works when NETRC env var is set and trust_env=True.""" client = await headers_echo_client(trust_env=True) async with client.get("/") as r: assert r.status == 200 content = await r.json() # Base64 encoded "netrc_user:netrc_pass" is "bmV0cmNfdXNlcjpuZXRyY19wYXNz" assert content["headers"]["Authorization"] == "Basic bmV0cmNfdXNlcjpuZXRyY19wYXNz" @pytest.mark.usefixtures("no_netrc") async def test_netrc_auth_skipped_without_netrc_file( # type: ignore[misc] headers_echo_client: Callable[ ..., Awaitable[TestClient[web.Request, web.Application]] ], ) -> None: """Test that netrc authentication is skipped when no netrc file exists.""" client = await headers_echo_client(trust_env=True) async with client.get("/") as r: assert r.status == 200 content = await r.json() # No Authorization header should be present assert "Authorization" not in content["headers"] @pytest.mark.usefixtures("netrc_home_directory") async def test_netrc_auth_from_home_directory( # type: ignore[misc] headers_echo_client: Callable[ ..., Awaitable[TestClient[web.Request, web.Application]] ], ) -> None: """Test that netrc authentication works from default ~/.netrc without NETRC env var.""" client = await headers_echo_client(trust_env=True) async with client.get("/") as r: assert r.status == 200 content = await r.json() assert content["headers"]["Authorization"] == "Basic bmV0cmNfdXNlcjpuZXRyY19wYXNz" @pytest.mark.usefixtures("netrc_default_contents") async def test_netrc_auth_overridden_by_explicit_auth( # type: ignore[misc] headers_echo_client: Callable[ ..., Awaitable[TestClient[web.Request, web.Application]] ], ) -> None: """Test that explicit auth parameter overrides netrc authentication.""" client 
= await headers_echo_client(trust_env=True) # Make request with explicit auth (should override netrc) async with client.get( "/", auth=aiohttp.BasicAuth("explicit_user", "explicit_pass") ) as r: assert r.status == 200 content = await r.json() # Base64 encoded "explicit_user:explicit_pass" is "ZXhwbGljaXRfdXNlcjpleHBsaWNpdF9wYXNz" assert ( content["headers"]["Authorization"] == "Basic ZXhwbGljaXRfdXNlcjpleHBsaWNpdF9wYXNz" ) async def test_session_headers( headers_echo_client: Callable[ ..., Awaitable[TestClient[web.Request, web.Application]] ], ) -> None: client = await headers_echo_client(headers={"X-Real-IP": "192.168.0.1"}) async with client.get("/") as r: assert r.status == 200 content = await r.json() assert content["headers"]["X-Real-IP"] == "192.168.0.1" async def test_session_headers_merge( headers_echo_client: Callable[ ..., Awaitable[TestClient[web.Request, web.Application]] ], ) -> None: client = await headers_echo_client( headers=[("X-Real-IP", "192.168.0.1"), ("X-Sent-By", "requests")] ) async with client.get("/", headers={"X-Sent-By": "aiohttp"}) as r: assert r.status == 200 content = await r.json() assert content["headers"]["X-Real-IP"] == "192.168.0.1" assert content["headers"]["X-Sent-By"] == "aiohttp" async def test_multidict_headers(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: assert await request.read() == data return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) data = b"sample data" async with client.post( "/", data=data, headers=MultiDict({"Content-Length": str(len(data))}) ) as r: assert r.status == 200 async def test_request_conn_closed(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: assert request.transport is not None request.transport.close() return web.Response() app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) with 
pytest.raises(aiohttp.ServerDisconnectedError) as excinfo: async with client.get("/") as resp: await resp.read() assert str(excinfo.value) != "" async def test_dont_close_explicit_connector(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) async with client.get("/") as r: await r.read() assert client.session.connector is not None assert 1 == len(client.session.connector._conns) async def test_server_close_keepalive_connection() -> None: loop = asyncio.get_event_loop() class Proto(asyncio.Protocol): def connection_made(self, transport: asyncio.BaseTransport) -> None: assert isinstance(transport, asyncio.Transport) self.transp: asyncio.Transport | None = transport self.data = b"" def data_received(self, data: bytes) -> None: self.data += data assert data.endswith(b"\r\n\r\n") assert self.transp is not None self.transp.write( b"HTTP/1.1 200 OK\r\n" b"CONTENT-LENGTH: 2\r\n" b"CONNECTION: close\r\n" b"\r\n" b"ok" ) self.transp.close() def connection_lost(self, exc: BaseException | None) -> None: self.transp = None server = await loop.create_server(Proto, "127.0.0.1", unused_port()) addr = server.sockets[0].getsockname() connector = aiohttp.TCPConnector(limit=1) async with aiohttp.ClientSession(connector=connector) as session: url = "http://{}:{}/".format(*addr) for i in range(2): r = await session.request("GET", url) await r.read() assert 0 == len(connector._conns) await connector.close() server.close() await server.wait_closed() async def test_handle_keepalive_on_closed_connection() -> None: loop = asyncio.get_event_loop() class Proto(asyncio.Protocol): def connection_made(self, transport: asyncio.BaseTransport) -> None: assert isinstance(transport, asyncio.Transport) self.transp: asyncio.Transport | None = transport self.data = b"" def data_received(self, data: bytes) -> None: self.data += data assert 
data.endswith(b"\r\n\r\n") assert self.transp is not None self.transp.write(b"HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 2\r\n\r\nok") self.transp.close() def connection_lost(self, exc: BaseException | None) -> None: self.transp = None server = await loop.create_server(Proto, "127.0.0.1", unused_port()) addr = server.sockets[0].getsockname() async with aiohttp.TCPConnector(limit=1) as connector: async with aiohttp.ClientSession(connector=connector) as session: url = "http://{}:{}/".format(*addr) r = await session.request("GET", url) await r.read() assert 1 == len(connector._conns) closed_conn = next(iter(connector._conns.values())) await session.request("GET", url) assert 1 == len(connector._conns) new_conn = next(iter(connector._conns.values())) assert closed_conn is not new_conn server.close() await server.wait_closed() async def test_error_in_performing_request( ssl_ctx: ssl.SSLContext, aiohttp_client: AiohttpClient, aiohttp_server: AiohttpServer, ) -> None: async def handler(request: web.Request) -> NoReturn: assert False def exception_handler(loop: object, context: object) -> None: """Skip log messages about destroyed but pending tasks""" loop = asyncio.get_event_loop() loop.set_exception_handler(exception_handler) app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app, ssl=ssl_ctx) conn = aiohttp.TCPConnector(limit=1) client = await aiohttp_client(server, connector=conn) with pytest.raises(aiohttp.ClientConnectionError): await client.get("/") # second try should not hang with pytest.raises(aiohttp.ClientConnectionError): await client.get("/") async def test_await_after_cancelling(aiohttp_client: AiohttpClient) -> None: loop = asyncio.get_event_loop() async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) fut1 = loop.create_future() fut2 = loop.create_future() async def fetch1() -> None: async with 
client.get("/") as resp: assert resp.status == 200 fut1.set_result(None) with pytest.raises(asyncio.CancelledError): await fut2 async def fetch2() -> None: await fut1 async with client.get("/") as resp: assert resp.status == 200 async def canceller() -> None: await fut1 fut2.cancel() await asyncio.gather(fetch1(), fetch2(), canceller()) async def test_async_payload_generator(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.read() assert data == b"1234567890" * 100 return web.Response() app = web.Application() app.add_routes([web.post("/", handler)]) client = await aiohttp_client(app) async def gen() -> AsyncIterator[bytes]: for i in range(100): yield b"1234567890" async with client.post("/", data=gen()) as resp: assert resp.status == 200 async def test_read_from_closed_response(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=b"data") app = web.Application() app.add_routes([web.get("/", handler)]) client = await aiohttp_client(app) async with client.get("/") as resp: assert resp.status == 200 with pytest.raises(aiohttp.ClientConnectionError): await resp.read() async def test_read_from_closed_response2(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=b"data") app = web.Application() app.add_routes([web.get("/", handler)]) client = await aiohttp_client(app) async with client.get("/") as resp: assert resp.status == 200 await resp.read() with pytest.raises(aiohttp.ClientConnectionError): await resp.read() async def test_json_from_closed_response(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.json_response(42) app = web.Application() app.add_routes([web.get("/", handler)]) client = await aiohttp_client(app) async with client.get("/") as resp: assert resp.status == 200 await resp.read() # Should not 
allow reading outside of resp context even when body is available. with pytest.raises(aiohttp.ClientConnectionError): await resp.json() async def test_text_from_closed_response(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(text="data") app = web.Application() app.add_routes([web.get("/", handler)]) client = await aiohttp_client(app) async with client.get("/") as resp: assert resp.status == 200 await resp.read() # Should not allow reading outside of resp context even when body is available. with pytest.raises(aiohttp.ClientConnectionError): await resp.text() async def test_read_after_catch_raise_for_status(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=b"data", status=404) app = web.Application() app.add_routes([web.get("/", handler)]) client = await aiohttp_client(app) async with client.get("/") as resp: with pytest.raises(ClientResponseError, match="404"): # Should not release response when in async with context. resp.raise_for_status() result = await resp.read() assert result == b"data" async def test_read_after_raise_outside_context(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=b"data", status=404) app = web.Application() app.add_routes([web.get("/", handler)]) client = await aiohttp_client(app) resp = await client.get("/") with pytest.raises(ClientResponseError, match="404"): # No async with, so should release and therefore read() will fail. 
resp.raise_for_status() with pytest.raises(aiohttp.ClientConnectionError, match=r"^Connection closed$"): await resp.read() async def test_read_from_closed_content(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=b"data") app = web.Application() app.add_routes([web.get("/", handler)]) client = await aiohttp_client(app) async with client.get("/") as resp: assert resp.status == 200 with pytest.raises(aiohttp.ClientConnectionError): await resp.content.readline() async def test_read_timeout(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: await asyncio.sleep(5) assert False app = web.Application() app.add_routes([web.get("/", handler)]) timeout = aiohttp.ClientTimeout(sock_read=0.1) client = await aiohttp_client(app, timeout=timeout) with pytest.raises(aiohttp.ServerTimeoutError): await client.get("/") async def test_socket_timeout(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: await asyncio.sleep(5) assert False app = web.Application() app.add_routes([web.get("/", handler)]) timeout = aiohttp.ClientTimeout(sock_read=0.1) client = await aiohttp_client(app, timeout=timeout) with pytest.raises(SocketTimeoutError): await client.get("/") async def test_read_timeout_closes_connection(aiohttp_client: AiohttpClient) -> None: request_count = 0 async def handler(request: web.Request) -> web.Response: nonlocal request_count request_count += 1 if request_count < 3: await asyncio.sleep(0.5) return web.Response(body=f"request:{request_count}") app = web.Application() app.add_routes([web.get("/", handler)]) timeout = aiohttp.ClientTimeout(total=0.1) client = await aiohttp_client(app, timeout=timeout) with pytest.raises(asyncio.TimeoutError): await client.get("/") # Make sure its really closed assert client.session.connector is not None assert not client.session.connector._conns with pytest.raises(asyncio.TimeoutError): 
await client.get("/") # Make sure its really closed assert not client.session.connector._conns async with client.get("/") as result: assert await result.read() == b"request:3" # Make sure its not closed assert client.session.connector._conns async def test_read_timeout_on_prepared_response(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: resp = aiohttp.web.StreamResponse() await resp.prepare(request) await asyncio.sleep(5) assert False app = web.Application() app.add_routes([web.get("/", handler)]) timeout = aiohttp.ClientTimeout(sock_read=0.1) client = await aiohttp_client(app, timeout=timeout) with pytest.raises(aiohttp.ServerTimeoutError): async with client.get("/") as resp: await resp.read() async def test_timeout_with_full_buffer(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: """Server response that never ends and always has more data available.""" resp = web.StreamResponse() await resp.prepare(request) while True: await resp.write(b"1" * 1000) await asyncio.sleep(0.01) async def request(client: TestClient[web.Request, web.Application]) -> None: timeout = aiohttp.ClientTimeout(total=0.5) async with client.get("/", timeout=timeout) as resp: with pytest.raises(asyncio.TimeoutError): async for data in resp.content.iter_chunked(1): await asyncio.sleep(0.01) app = web.Application() app.add_routes([web.get("/", handler)]) client = await aiohttp_client(app) # wait_for() used just to ensure that a failing test doesn't hang. 
async def test_http_empty_data_text(aiohttp_client: AiohttpClient) -> None:
    """POSTing an empty string delivers an empty body with a text/plain content type."""

    async def handler(request: web.Request) -> web.Response:
        received = await request.read()
        if received == b"":
            verdict = "ok"
        else:
            verdict = "fail"
        reply = web.Response(text=verdict)
        # Echo back the Content-Type the client sent so the test can inspect it.
        reply.headers["Content-Type"] = request.headers["Content-Type"]
        return reply

    application = web.Application()
    application.add_routes([web.post("/", handler)])
    test_client = await aiohttp_client(application)

    async with test_client.post("/", data="") as response:
        assert response.status == 200
        body_text = await response.text()
        assert body_text == "ok"
        assert response.headers["Content-Type"] == "text/plain; charset=utf-8"
async def test_max_headers_session_explicit(aiohttp_client: AiohttpClient) -> None:
    """A session created with ``max_headers=140`` accepts a response carrying 130 custom headers."""

    async def handler(request: web.Request) -> web.Response:
        # Custom-0 .. Custom-129, every one with the value "x".
        many_headers = dict.fromkeys((f"Custom-{i}" for i in range(130)), "x")
        return web.Response(headers=many_headers)

    application = web.Application()
    application.add_routes([web.get("/", handler)])
    test_client = await aiohttp_client(application, max_headers=140)

    async with test_client.get("/") as response:
        assert response.headers["Custom-129"] == "x"
async def test_max_line_size_session_explicit(aiohttp_client: AiohttpClient) -> None:
    """A session created with ``max_line_size=8210`` accepts an 8197-character reason phrase."""
    long_reason = "x" * 8197

    async def handler(request: web.Request) -> web.Response:
        return web.Response(status=200, reason=long_reason)

    application = web.Application()
    application.add_routes([web.get("/", handler)])
    test_client = await aiohttp_client(application, max_line_size=8210)

    async with test_client.get("/") as response:
        assert response.reason == long_reason
async def test_header_too_large_error(aiohttp_client: AiohttpClient) -> None:
    """By default when not specifying `max_field_size` requests should fail with a 400 status code."""

    async def handler(_: web.Request) -> web.Response:
        # 10000 characters is well above the 8190-byte limit named in the expected error.
        return web.Response(headers={"VeryLargeHeader": "x" * 10000})

    app = web.Application()
    app.add_routes([web.get("/", handler)])
    client = await aiohttp_client(app)

    # ``match`` is applied with ``re.search``. The previous pattern ended in
    # ``bytes*``, where ``*`` is a regex quantifier (zero or more "s") rather than
    # a wildcard, so it silently accepted "byte" as well. Match the literal text.
    with pytest.raises(
        aiohttp.ClientResponseError, match=r"Got more than 8190 bytes"
    ) as exc_info:
        await client.get("/")
    assert exc_info.value.status == 400
async def test_post_connection_cleanup_with_bytesio(
    aiohttp_client: AiohttpClient,
) -> None:
    """Check the pool holds exactly one entry after POSTs with bytes and BytesIO bodies."""

    async def handler(request: web.Request) -> web.Response:
        return web.Response(body=b"")

    application = web.Application()
    application.router.add_post("/hello", handler)
    client = await aiohttp_client(application)

    one_byte_headers = {"Content-Length": "1"}

    # Alternate between a raw bytes body and a BytesIO body, several times over.
    for _ in range(10):
        async with client.post(
            "/hello", data=b"x", headers=one_byte_headers
        ) as bytes_response:
            bytes_response.raise_for_status()

        connector = client._session.connector
        assert connector is not None
        assert len(connector._conns) == 1

        buffer = io.BytesIO(b"x")
        async with client.post(
            "/hello", data=buffer, headers=one_byte_headers
        ) as stream_response:
            stream_response.raise_for_status()

        assert len(client._session.connector._conns) == 1
async def test_post_content_exception_connection_kept(
    aiohttp_client: AiohttpClient,
) -> None:
    """Test that connections are kept after content.set_exception() with POST."""

    async def handler(request: web.Request) -> web.Response:
        await request.read()
        return web.Response(
            body=b"x" * 1000
        )  # Larger response to ensure it's not pre-buffered

    app = web.Application()
    app.router.add_post("/", handler)
    client = await aiohttp_client(app)

    # POST request with body - the connection is expected to stay in the pool
    # even though reading the response content fails below
    resp = await client.post("/", data=b"request body")

    # Inject an error into the body stream so that read() raises it
    with pytest.raises(RuntimeError):
        async with resp:
            assert resp.status == 200
            resp.content.set_exception(RuntimeError("Simulated error"))
            await resp.read()

    assert resp.closed

    # Wait for any pending operations to complete
    await resp.wait_for_close()

    assert client._session.connector is not None
    # Connection is kept because content.set_exception() is a client-side operation
    # that doesn't affect the underlying connection state
    assert len(client._session.connector._conns) == 1
async def test_empty_response_non_chunked(aiohttp_client: AiohttpClient) -> None:
    """Check that a response with ``Content-Length: 0`` reads back as an empty body."""

    async def handler(request: web.Request) -> web.Response:
        # Explicit zero-length body with a matching Content-Length header.
        return web.Response(body=b"", headers={"Content-Length": "0"})

    application = web.Application()
    application.router.add_get("/empty", handler)
    test_client = await aiohttp_client(application)

    response = await test_client.get("/empty")
    assert response.status == 200
    assert response.headers.get("Content-Length") == "0"
    assert await response.read() == b""
    response.close()
async def test_string_payload_redirect(aiohttp_client: AiohttpClient) -> None:
    """Check that a StringPayload body is delivered to both hops of a 307 redirect."""
    seen = []

    async def redirect_handler(request: web.Request) -> web.Response:
        seen.append(("redirect", await request.text()))
        # 307 keeps the POST method (and therefore the body) on the next hop.
        raise web.HTTPTemporaryRedirect("/final_destination")

    async def final_handler(request: web.Request) -> web.Response:
        data = await request.text()
        seen.append(("final", data))
        return web.Response(text=f"Received: {data}")

    application = web.Application()
    application.router.add_post("/redirect", redirect_handler)
    application.router.add_post("/final_destination", final_handler)
    test_client = await aiohttp_client(application)

    payload_text = "test string payload"
    response = await test_client.post("/redirect", data=StringPayload(payload_text))

    assert response.status == 200
    assert await response.text() == "Received: test string payload"
    # Each endpoint must have seen the complete payload, in request order.
    assert seen == [("redirect", payload_text), ("final", payload_text)]
@pytest.mark.parametrize("status", (301, 302))
async def test_async_iterable_payload_redirect_non_post_301_302(
    aiohttp_client: AiohttpClient, status: int
) -> None:
    """Check that a PUT with a consumed async-iterable body fails on a 301/302 redirect."""
    seen = []

    async def redirect_handler(request: web.Request) -> web.Response:
        seen.append(("redirect", await request.read()))
        return web.Response(status=status, headers={"Location": "/final_destination"})

    application = web.Application()
    application.router.add_put("/redirect", redirect_handler)
    test_client = await aiohttp_client(application)

    pieces = [b"chunk1", b"chunk2", b"chunk3"]

    async def async_gen() -> AsyncIterator[bytes]:
        for piece in pieces:
            yield piece

    body = AsyncIterablePayload(async_gen())

    with pytest.raises(
        aiohttp.ClientPayloadError,
        match="Cannot follow redirect with a consumed request body",
    ):
        await test_client.put("/redirect", data=body)

    # Only the redirecting endpoint was reached, and it received every chunk.
    assert seen == [("redirect", b"".join(pieces))]
async def test_string_io_payload_redirect(aiohttp_client: AiohttpClient) -> None:
    """Check that a StringIOPayload body is delivered to both hops of a 307 redirect."""
    seen = []

    async def redirect_handler(request: web.Request) -> web.Response:
        seen.append(("redirect", await request.text()))
        # 307 keeps the POST method (and therefore the body) on the next hop.
        raise web.HTTPTemporaryRedirect("/final_destination")

    async def final_handler(request: web.Request) -> web.Response:
        data = await request.text()
        seen.append(("final", data))
        return web.Response(text=f"Received: {data}")

    application = web.Application()
    application.router.add_post("/redirect", redirect_handler)
    application.router.add_post("/final_destination", final_handler)
    test_client = await aiohttp_client(application)

    payload_text = "string io payload"
    body = StringIOPayload(io.StringIO(payload_text))

    response = await test_client.post("/redirect", data=body)
    assert response.status == 200
    assert await response.text() == "Received: string io payload"
    # Each endpoint must have seen the complete payload, in request order.
    assert seen == [("redirect", payload_text), ("final", payload_text)]
async def test_multiple_redirects_with_bytes_payload(
    aiohttp_client: AiohttpClient,
) -> None:
    """Check that a BytesPayload body reaches all three hops of a two-redirect chain."""
    seen = []

    async def redirect1_handler(request: web.Request) -> web.Response:
        seen.append(("redirect1", await request.read()))
        # 307 keeps the POST method (and therefore the body) on the next hop.
        raise web.HTTPTemporaryRedirect("/redirect2")

    async def redirect2_handler(request: web.Request) -> web.Response:
        seen.append(("redirect2", await request.read()))
        # 307 again, for the same reason.
        raise web.HTTPTemporaryRedirect("/final_destination")

    async def final_handler(request: web.Request) -> web.Response:
        data = await request.read()
        seen.append(("final", data))
        return web.Response(text=f"Received after 2 redirects: {data.decode()}")

    application = web.Application()
    for path, route_handler in (
        ("/redirect", redirect1_handler),
        ("/redirect2", redirect2_handler),
        ("/final_destination", final_handler),
    ):
        application.router.add_post(path, route_handler)
    test_client = await aiohttp_client(application)

    payload_data = b"multi-redirect-test"
    response = await test_client.post("/redirect", data=BytesPayload(payload_data))

    assert response.status == 200
    assert (
        await response.text()
        == f"Received after 2 redirects: {payload_data.decode()}"
    )
    # All 3 endpoints must have received identical data, in request order.
    assert seen == [
        (hop, payload_data) for hop in ("redirect1", "redirect2", "final")
    ]
async def test_redirect_preserves_content_type(aiohttp_client: AiohttpClient) -> None:
    """Check that both hops of a 307 redirect see the same request content type."""
    observed = []

    async def redirect_handler(request: web.Request) -> web.Response:
        observed.append(("redirect", request.content_type))
        # 307 keeps the POST method on the next hop.
        raise web.HTTPTemporaryRedirect("/final_destination")

    async def final_handler(request: web.Request) -> web.Response:
        observed.append(("final", request.content_type))
        return web.Response(text="Done")

    application = web.Application()
    application.router.add_post("/redirect", redirect_handler)
    application.router.add_post("/final_destination", final_handler)
    test_client = await aiohttp_client(application)

    # Send a StringPayload; each handler records the content type it was given.
    response = await test_client.post("/redirect", data=StringPayload("test data"))
    assert response.status == 200

    # Exactly two requests were handled and both arrived as text/plain.
    assert [content_type for _, content_type in observed] == [
        "text/plain",
        "text/plain",
    ]
async def test_too_many_redirects_closes_payload(aiohttp_client: AiohttpClient) -> None:
    """Check that the request payload is closed when TooManyRedirects is raised."""

    async def redirect_handler(request: web.Request) -> web.Response:
        # Drain the body first, as a server that processes the upload would.
        await request.read()
        count = int(request.match_info.get("count", 0))
        # 307 keeps the POST method; every hop points at the next counter value.
        next_hop = {hdrs.LOCATION: f"/redirect/{count + 1}"}
        return web.Response(status=307, headers=next_hop)

    application = web.Application()
    application.router.add_post(r"/redirect/{count:\d+}", redirect_handler)
    test_client = await aiohttp_client(application)

    # The mocked payload records whether close() ran.
    tracked = MockedBytesPayload(b"test payload")

    with pytest.raises(TooManyRedirects):
        await test_client.post("/redirect/0", data=tracked, max_redirects=2)

    assert (
        tracked.close_called
    ), "Payload.close() was not called when TooManyRedirects was raised"
async def test_non_http_redirect_closes_payload(aiohttp_client: AiohttpClient) -> None:
    """Check that the request payload is closed when NonHttpUrlRedirectClientError is raised."""

    async def redirect_handler(request: web.Request) -> web.Response:
        # Drain the body first, as a server that processes the upload would.
        await request.read()
        # Redirect to a URL whose scheme is not HTTP.
        ftp_location = {hdrs.LOCATION: "ftp://example.com/file"}
        return web.Response(status=307, headers=ftp_location)

    application = web.Application()
    application.router.add_post("/redirect", redirect_handler)
    test_client = await aiohttp_client(application)

    # The mocked payload records whether close() ran.
    tracked = MockedBytesPayload(b"test payload")

    with pytest.raises(NonHttpUrlRedirectClientError):
        await test_client.post("/redirect", data=tracked)

    assert (
        tracked.close_called
    ), "Payload.close() was not called when NonHttpUrlRedirectClientError was raised"
real-world cookie scenario similar to Amazon.""" class FakeResolver(AbstractResolver): def __init__(self, port: int): self._port = port async def resolve( self, host: str, port: int = 0, family: int = 0 ) -> list[ResolveResult]: if host in ("amazon.it", "www.amazon.it"): return [ { "hostname": host, "host": "127.0.0.1", "port": self._port, "family": socket.AF_INET, "proto": 0, "flags": 0, } ] assert False, f"Unexpected host: {host}" async def close(self) -> None: """Close the resolver if needed.""" async def handler(request: web.Request) -> web.Response: response = web.Response(text="Login successful") # Simulate Amazon-like cookies from the issue cookies = [ "session-id=146-7423990-7621939; Domain=.amazon.it; " "Expires=Mon, 31-May-3024 10:00:00 GMT; Path=/; " "Secure; HttpOnly", "session-id=147-8529641-8642103; Domain=.www.amazon.it; " "Expires=Mon, 31-May-3024 10:00:00 GMT; Path=/; HttpOnly", "session-id-time=2082758401l; Domain=.amazon.it; " "Expires=Mon, 31-May-3024 10:00:00 GMT; Path=/; Secure", "session-id-time=2082758402l; Domain=.www.amazon.it; " "Expires=Mon, 31-May-3024 10:00:00 GMT; Path=/", "ubid-acbit=257-7531983-5395266; Domain=.amazon.it; " "Expires=Mon, 31-May-3024 10:00:00 GMT; Path=/; Secure", 'x-acbit="KdvJzu8W@Fx6Jj3EuNFLuP0N7OtkuCfs"; Version=1; ' "Domain=.amazon.it; Path=/; Secure; HttpOnly", "at-acbit=Atza|IwEBIM-gLr8; Domain=.amazon.it; " "Expires=Mon, 31-May-3024 10:00:00 GMT; Path=/; " "Secure; HttpOnly", 'sess-at-acbit="4+6VzSJPHIFD/OqO264hFxIng8Y="; ' "Domain=.amazon.it; Expires=Mon, 31-May-3024 10:00:00 GMT; " "Path=/; Secure; HttpOnly", "lc-acbit=it_IT; Domain=.amazon.it; " "Expires=Mon, 31-May-3024 10:00:00 GMT; Path=/", "i18n-prefs=EUR; Domain=.amazon.it; " "Expires=Mon, 31-May-3024 10:00:00 GMT; Path=/", "av-profile=null; Domain=.amazon.it; " "Expires=Mon, 31-May-3024 10:00:00 GMT; Path=/; Secure", 'user-pref-token="Am81ywsJ69xObBnuJ2FbilVH0mg="; ' "Domain=.amazon.it; Path=/; Secure", ] for cookie in cookies: 
response.headers.add("Set-Cookie", cookie) return response app = web.Application() app.router.add_get("/", handler) # Get the test server server = await aiohttp_client(app) port = server.port # Create a new client session with our fake resolver resolver = FakeResolver(port) async with ( aiohttp.TCPConnector(resolver=resolver, force_close=True) as connector, aiohttp.ClientSession(connector=connector) as session, ): # Make request to www.amazon.it which will resolve to # 127.0.0.1:port. This allows cookies for both .amazon.it # and .www.amazon.it domains resp = await session.get(f"http://www.amazon.it:{port}/") # Check headers cookie_headers = resp.headers.getall("Set-Cookie") assert ( len(cookie_headers) == 12 ), f"Expected 12 headers, got {len(cookie_headers)}" # Check parsed cookies - SimpleCookie only keeps the last # cookie with each name. So we expect 10 unique cookie names # (not 12) expected_cookie_names = { "session-id", # Will only have one "session-id-time", # Will only have one "ubid-acbit", "x-acbit", "at-acbit", "sess-at-acbit", "lc-acbit", "i18n-prefs", "av-profile", "user-pref-token", } assert set(resp.cookies.keys()) == expected_cookie_names assert ( len(resp.cookies) == 10 ), f"Expected 10 cookies in SimpleCookie, got {len(resp.cookies)}" # The important part: verify the session's cookie jar has # all cookies. 
@pytest.mark.parametrize("status", (307, 308))
async def test_file_upload_307_308_redirect(
    aiohttp_client: AiohttpClient, tmp_path: pathlib.Path, status: int
) -> None:
    """Upload a file through a 307/308 redirect and check it arrives intact.

    The trailing-slash endpoint redirects to the same path without the
    slash. Both requests must receive the complete file body, and the
    final request must report a Content-Length equal to the file size.
    """
    bodies_seen: list[bytes] = []

    async def upload_endpoint(request: web.Request) -> web.Response:
        payload = await request.read()
        bodies_seen.append(payload)  # remember every body the server received

        url = request.url
        if not str(url.path).endswith("/"):
            # Final target: report what was received.
            return web.json_response(
                {
                    "received_size": len(payload),
                    "content_length": request.headers.get("Content-Length"),
                }
            )

        # Trailing slash: redirect to the slash-less variant of the same URL.
        target = url.with_path(url.path.rstrip("/"))
        return web.Response(status=status, headers={"Location": str(target)})

    application = web.Application()
    for route in ("/upload/", "/upload"):
        application.router.add_post(route, upload_endpoint)

    http_client = await aiohttp_client(application)

    # Write the fixture file off the event loop.
    content = b"This is test file content for upload."
    upload_path = tmp_path / f"test_upload_{status}.txt"
    await asyncio.to_thread(upload_path.write_bytes, content)
    size = len(content)

    # Posting to the trailing-slash URL forces the redirect.
    fobj = await asyncio.to_thread(open, upload_path, "rb")
    try:
        async with http_client.post("/upload/", data=fobj) as resp:
            assert resp.status == 200
            report = await resp.json()

            # The final request carried the whole file.
            assert report["received_size"] == size
            assert report["content_length"] == str(size)

            # The redirecting request and the redirected one saw the same bytes.
            assert len(bodies_seen) == 2
            first_body, redirected_body = bodies_seen
            assert first_body == content
            assert redirected_body == content
    finally:
        await asyncio.to_thread(fobj.close)
""" received_bodies: list[bytes] = [] async def handler(request: web.Request) -> web.Response: # Store the body content body = await request.read() received_bodies.append(body) if str(request.url.path).endswith("/"): # Redirect URLs ending with / to remove the trailing slash return web.Response( status=status, headers={ "Location": str(request.url.with_path(request.url.path.rstrip("/"))) }, ) # Return success with the body size return web.json_response( { "method": request.method, "received_size": len(body), "content_length": request.headers.get("Content-Length"), } ) app = web.Application() app.router.add_route(method, "/upload/", handler) app.router.add_route(method, "/upload", handler) client = await aiohttp_client(app) # Create a test file test_file = tmp_path / f"test_upload_{status}_{method.lower()}.txt" content = f"Test {method} file content for {status} redirect.".encode() await asyncio.to_thread(test_file.write_bytes, content) expected_size = len(content) # Upload file to URL with trailing slash (will trigger redirect) f = await asyncio.to_thread(open, test_file, "rb") try: async with client.request(method, "/upload/", data=f) as resp: assert resp.status == 200 result = await resp.json() # The server should receive the full file content after redirect assert result["method"] == method # Method should be preserved assert result["received_size"] == expected_size assert result["content_length"] == str(expected_size) # Both requests should have received the same content assert len(received_bodies) == 2 assert received_bodies[0] == content # First request assert received_bodies[1] == content # After redirect finally: await asyncio.to_thread(f.close) async def test_file_upload_307_302_redirect_chain( aiohttp_client: AiohttpClient, tmp_path: pathlib.Path ) -> None: """Test that file uploads work correctly with 307->302->200 redirect chain. This verifies that: 1. 307 preserves POST method and file body 2. 302 changes POST to GET and drops the body 3. 
No body leaks to the final GET request """ received_requests: list[dict[str, Any]] = [] async def handler(request: web.Request) -> web.Response: # Store request details body = await request.read() received_requests.append( { "path": str(request.url.path), "method": request.method, "body_size": len(body), "content_length": request.headers.get("Content-Length"), } ) if request.url.path == "/upload307": # First redirect: 307 should preserve method and body return web.Response(status=307, headers={"Location": "/upload302"}) elif request.url.path == "/upload302": # Second redirect: 302 should change POST to GET return web.Response(status=302, headers={"Location": "/final"}) else: # Final destination return web.json_response( { "final_method": request.method, "final_body_size": len(body), "requests_received": len(received_requests), } ) app = web.Application() app.router.add_route("*", "/upload307", handler) app.router.add_route("*", "/upload302", handler) app.router.add_route("*", "/final", handler) client = await aiohttp_client(app) # Create a test file test_file = tmp_path / "test_redirect_chain.txt" content = b"Test file content that should not leak to GET request" await asyncio.to_thread(test_file.write_bytes, content) expected_size = len(content) # Upload file to URL that triggers 307->302->final redirect chain f = await asyncio.to_thread(open, test_file, "rb") try: async with client.post("/upload307", data=f) as resp: assert resp.status == 200 result = await resp.json() # Verify the redirect chain assert len(resp.history) == 2 assert resp.history[0].status == 307 assert resp.history[1].status == 302 # Verify final request is GET with no body assert result["final_method"] == "GET" assert result["final_body_size"] == 0 assert result["requests_received"] == 3 # Verify the request sequence assert len(received_requests) == 3 # First request (307): POST with full body assert received_requests[0]["path"] == "/upload307" assert received_requests[0]["method"] == "POST" 
assert received_requests[0]["body_size"] == expected_size assert received_requests[0]["content_length"] == str(expected_size) # Second request (302): POST with preserved body from 307 assert received_requests[1]["path"] == "/upload302" assert received_requests[1]["method"] == "POST" assert received_requests[1]["body_size"] == expected_size assert received_requests[1]["content_length"] == str(expected_size) # Third request (final): GET with no body (302 changed method and dropped body) assert received_requests[2]["path"] == "/final" assert received_requests[2]["method"] == "GET" assert received_requests[2]["body_size"] == 0 assert received_requests[2]["content_length"] is None finally: await asyncio.to_thread(f.close) async def test_stream_reader_total_raw_bytes(aiohttp_client: AiohttpClient) -> None: """Test whether StreamReader.total_raw_bytes returns the number of bytes downloaded""" source_data = b"@dKal^pH>1h|YW1:c2J$" * 4096 async def handler(request: web.Request) -> web.Response: response = web.Response(body=source_data) response.enable_compression() return response app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) # Check for decompressed data async with client.get( "/", headers={"Accept-Encoding": "gzip"}, auto_decompress=True ) as resp: assert resp.headers["Content-Encoding"] == "gzip" assert int(resp.headers["Content-Length"]) < len(source_data) data = await resp.content.read() assert len(data) == len(source_data) assert resp.content.total_raw_bytes == int(resp.headers["Content-Length"]) # Check for compressed data async with client.get( "/", headers={"Accept-Encoding": "gzip"}, auto_decompress=False ) as resp: assert resp.headers["Content-Encoding"] == "gzip" data = await resp.content.read() assert resp.content.total_raw_bytes == len(data) assert resp.content.total_raw_bytes == int(resp.headers["Content-Length"]) # Check for non-compressed data async with client.get( "/", headers={"Accept-Encoding": "identity"}, 
async def test_client_middleware_called(aiohttp_server: AiohttpServer) -> None:
    """A middleware registered on the session runs for a request made through it."""
    flags: dict[str, bool] = {"middleware": False}
    request_count = 0

    async def count_requests(request: web.Request) -> web.Response:
        nonlocal request_count
        request_count += 1
        return web.Response(text=f"OK {request_count}")

    async def recording_middleware(
        request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        # Note the invocation, then pass the request through unchanged.
        flags["middleware"] = True
        return await handler(request)

    application = web.Application()
    application.router.add_get("/", count_requests)
    srv = await aiohttp_server(application)

    async with ClientSession(middlewares=(recording_middleware,)) as session:
        async with session.get(srv.make_url("/")) as resp:
            assert resp.status == 200
            body = await resp.text()
            assert body == "OK 1"

    assert flags["middleware"] is True
    assert request_count == 1
async def handler(request: web.Request) -> web.Response: nonlocal request_count request_count += 1 if request_count == 1: return web.Response(status=503) return web.Response(text=f"OK {request_count}") async def retry_middleware( request: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: response = None for _ in range(2): # pragma: no branch response = await handler(request) if response.ok: return response assert False, "not reachable in test" app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app) async with ClientSession(middlewares=(retry_middleware,)) as session: async with session.get(server.make_url("/")) as resp: assert resp.status == 200 text = await resp.text() assert text == "OK 2" assert request_count == 2 async def test_client_middleware_per_request(aiohttp_server: AiohttpServer) -> None: """Test that middleware can be specified per request.""" session_middleware_called = False request_middleware_called = False async def handler(request: web.Request) -> web.Response: return web.Response(text="OK") async def session_middleware( request: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: nonlocal session_middleware_called session_middleware_called = True response = await handler(request) return response async def request_middleware( request: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: nonlocal request_middleware_called request_middleware_called = True response = await handler(request) return response app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app) # Request with session middleware async with ClientSession(middlewares=(session_middleware,)) as session: async with session.get(server.make_url("/")) as resp: assert resp.status == 200 assert session_middleware_called is True assert request_middleware_called is False # Reset flags session_middleware_called = False # Request with override middleware async with 
async def test_multiple_client_middlewares(aiohttp_server: AiohttpServer) -> None:
    """Two session middlewares nest: the first one listed is the outermost."""
    trace: list[str] = []

    async def ok_handler(request: web.Request) -> web.Response:
        return web.Response(text="OK")

    async def outer_middleware(
        request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        trace.append("before1")
        result = await handler(request)
        trace.append("after1")
        return result

    async def inner_middleware(
        request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        trace.append("before2")
        result = await handler(request)
        trace.append("after2")
        return result

    application = web.Application()
    application.router.add_get("/", ok_handler)
    srv = await aiohttp_server(application)

    chain = (outer_middleware, inner_middleware)
    async with ClientSession(middlewares=chain) as session:
        async with session.get(srv.make_url("/")) as resp:
            assert resp.status == 200

    # The first middleware wraps the second, so its "after" entry comes last.
    expected_order = ["before1", "before2", "after2", "after1"]
    assert trace == expected_order
async def test_client_middleware_challenge_auth(aiohttp_server: AiohttpServer) -> None:
    """Middleware answers a 401 nonce challenge and the retried request succeeds."""
    request_count = 0
    challenge_token = "challenge-123"

    async def challenge_handler(request: web.Request) -> web.Response:
        nonlocal request_count
        request_count += 1

        auth_header = request.headers.get("Authorization")
        is_first_unauthenticated = request_count == 1 and not auth_header

        if is_first_unauthenticated:
            # Issue the challenge carrying the nonce.
            challenge = f'Custom realm="test", nonce="{challenge_token}"'
            return web.Response(status=401, headers={"WWW-Authenticate": challenge})

        # Any later request must present the answer derived from the nonce.
        if auth_header == f'Custom response="{challenge_token}-secret"':
            return web.Response(text="Authenticated")

        assert False, "Should not reach here - invalid auth scenario"

    async def challenge_auth_middleware(
        request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        nonce: str | None = None
        attempted: bool = False
        nonce_marker = 'nonce="'

        while True:
            # Once a challenge has been captured, answer it on this attempt.
            if nonce and attempted:
                request.headers["Authorization"] = f'Custom response="{nonce}-secret"'

            response = await handler(request)

            # Only the first 401 triggers a retry.
            if response.status == 401 and not attempted:
                www_auth = response.headers.get("WWW-Authenticate")
                if www_auth and "nonce=" in www_auth:
                    # Slice out the text between the quotes following nonce=.
                    begin = www_auth.find(nonce_marker) + len(nonce_marker)
                    end = www_auth.find('"', begin)
                    nonce = www_auth[begin:end]
                    attempted = True
                    continue
                else:
                    assert False, "Should not reach here"

            return response

    application = web.Application()
    application.router.add_get("/", challenge_handler)
    srv = await aiohttp_server(application)

    async with ClientSession(middlewares=(challenge_auth_middleware,)) as session:
        async with session.get(srv.make_url("/")) as resp:
            assert resp.status == 200
            body = await resp.text()
            assert body == "Authenticated"

    # One challenged request plus one authenticated retry.
    assert request_count == 2
async def test_client_middleware_conditional_retry(
    aiohttp_server: AiohttpServer,
) -> None:
    """Middleware inspects a 401 JSON body and retries once with a refreshed token."""
    request_count = 0
    token_state: dict[str, str | bool] = {"token": "old-token", "refreshed": False}

    async def token_handler(request: web.Request) -> web.Response:
        nonlocal request_count
        request_count += 1

        presented = request.headers.get("X-Auth-Token")

        if request_count == 1:
            # The very first request is always rejected as expired.
            expired = {"error": "token_expired", "refresh_required": True}
            return web.json_response(expired, status=401)

        if presented == "refreshed-token":
            return web.json_response({"data": "success"})

        assert False, "Should not reach here - invalid token refresh flow"

    async def token_refresh_middleware(
        request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        for _ in range(2):
            # Always send whatever token is current.
            request.headers["X-Auth-Token"] = str(token_state["token"])

            response = await handler(request)

            # A 401 is only acted upon while no refresh has happened yet.
            if response.status == 401 and not token_state["refreshed"]:
                payload = await response.json()
                expired = payload.get("error") == "token_expired"
                if expired and payload.get("refresh_required"):
                    # Stand-in for a real token refresh.
                    token_state["token"] = "refreshed-token"
                    token_state["refreshed"] = True
                    continue
                else:
                    assert False, "Should not reach here"

            return response

        # Not expected to run; gives the function a terminating statement for mypy.
        assert False, "Should not reach here"

    application = web.Application()
    application.router.add_get("/", token_handler)
    srv = await aiohttp_server(application)

    async with ClientSession(middlewares=(token_refresh_middleware,)) as session:
        async with session.get(srv.make_url("/")) as resp:
            assert resp.status == 200
            result = await resp.json()
            assert result == {"data": "success"}

    # The rejected request plus the retry after the refresh.
    assert request_count == 2
async def test_client_middleware_stateful_retry(aiohttp_server: AiohttpServer) -> None:
    """A class-based middleware retries 5xx responses up to its configured limit."""

    class RetryMiddleware:
        """Callable middleware that re-issues a request while it gets a 5xx status."""

        def __init__(self, max_retries: int = 3) -> None:
            self.max_retries = max_retries

        async def __call__(
            self, request: ClientRequest, handler: ClientHandlerType
        ) -> ClientResponse:
            retries_left = self.max_retries
            response = await handler(request)
            # Retry immediately until a non-5xx arrives or the budget is spent.
            while response.status >= 500 and retries_left > 0:
                retries_left -= 1
                response = await handler(request)
            return response

    request_count = 0

    async def flaky_handler(request: web.Request) -> web.Response:
        nonlocal request_count
        request_count += 1
        # The first two requests fail; the third succeeds.
        if request_count >= 3:
            return web.Response(text="Success")
        return web.Response(status=503)

    application = web.Application()
    application.router.add_get("/", flaky_handler)
    srv = await aiohttp_server(application)

    retrier = RetryMiddleware(max_retries=2)

    async with ClientSession(middlewares=(retrier,)) as session:
        async with session.get(srv.make_url("/")) as resp:
            assert resp.status == 200
            body = await resp.text()
            assert body == "Success"

    # One initial attempt plus two retries.
    assert request_count == 3
async def test_request_middleware_overrides_session_middleware_with_empty(
    aiohttp_server: AiohttpServer,
) -> None:
    """An empty per-request ``middlewares`` tuple switches off the session middleware."""
    flags: dict[str, bool] = {"session": False}

    async def echo_auth(request: web.Request) -> web.Response:
        auth_header = request.headers.get("Authorization")
        reply = f"Auth: {auth_header}" if auth_header else "No auth"
        return web.Response(text=reply)

    async def session_middleware(
        request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        # Record the call and stamp the request with the session token.
        flags["session"] = True
        request.headers["Authorization"] = "Bearer session-token"
        return await handler(request)

    application = web.Application()
    application.router.add_get("/", echo_auth)
    srv = await aiohttp_server(application)
    url = srv.make_url("/")

    async with ClientSession(middlewares=(session_middleware,)) as session:
        # Without an override the session middleware is applied.
        async with session.get(url) as resp:
            assert resp.status == 200
            body = await resp.text()
            assert body == "Auth: Bearer session-token"
            assert flags["session"] is True

        # Start the second request from a clean flag.
        flags["session"] = False

        # An explicit empty tuple means no middleware for this request.
        async with session.get(url, middlewares=()) as resp:
            assert resp.status == 200
            body = await resp.text()
            assert body == "No auth"
            assert flags["session"] is False
@pytest.mark.parametrize(
    "exception_class,match_text",
    [
        (ValueError, "Middleware error"),
        (ClientError, "Client error from middleware"),
        (OSError, "OS error from middleware"),
    ],
)
async def test_client_middleware_exception_closes_connection(
    aiohttp_server: AiohttpServer,
    exception_class: type[Exception],
    match_text: str,
) -> None:
    """Verify no connection stays pooled when a middleware raises.

    The middleware raises ``exception_class(match_text)`` without calling the
    next handler. The exception must reach the caller, and the connector's
    ``_conns`` pool must be empty afterwards.
    """

    async def handler(request: web.Request) -> NoReturn:
        assert False, "Handler should not be reached"

    async def failing_middleware(
        request: ClientRequest, handler: ClientHandlerType
    ) -> NoReturn:
        # Fail up front; the downstream handler is never awaited.
        raise exception_class(match_text)

    app = web.Application()
    app.router.add_get("/", handler)
    server = await aiohttp_server(app)

    # A dedicated connector lets the test inspect its pool afterwards.
    connector = TCPConnector()
    client = ClientSession(connector=connector, middlewares=(failing_middleware,))

    async with client as session:
        with pytest.raises(exception_class, match=match_text):
            await session.get(server.make_url("/"))

    # Nothing should be left in the pool.
    assert len(connector._conns) == 0

    await connector.close()
async def test_client_middleware_blocks_connection_without_dns_lookup(
    aiohttp_server: AiohttpServer,
) -> None:
    """Verify that a middleware rejection happens before any DNS resolution.

    A resolver subclass records every hostname it is asked to resolve. The
    blocked host must never appear in that record, while an allowed request
    to ``localhost`` must produce at least one lookup.
    """
    blocked_hosts = {"blocked.domain.tld"}
    dns_lookups_made: list[str] = []

    # Minimal server used as the target of the allowed request.
    async def handler(request: web.Request) -> web.Response:
        return web.Response(text="OK")

    app = web.Application()
    app.router.add_get("/", handler)
    server = await aiohttp_server(app)

    class TrackingResolver(ThreadedResolver):
        async def resolve(
            self,
            hostname: str,
            port: int = 0,
            family: socket.AddressFamily = socket.AF_INET,
        ) -> list[ResolveResult]:
            # Record the hostname, then delegate to the real resolver.
            dns_lookups_made.append(hostname)
            return await super().resolve(hostname, port, family)

    async def blocking_middleware(
        request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        # Allowed hosts pass straight through; blocked ones fail here.
        if request.url.host not in blocked_hosts:
            return await handler(request)
        raise BlockedByMiddleware(f"Blocked by policy: {request.url.host}")

    connector = TCPConnector(resolver=TrackingResolver())

    async with ClientSession(
        connector=connector, middlewares=(blocking_middleware,)
    ) as session:
        # The blocked domain does not exist; a lookup attempt would be visible.
        with pytest.raises(BlockedByMiddleware) as exc_info:
            await session.get("https://blocked.domain.tld/")

        assert "Blocked by policy: blocked.domain.tld" in str(exc_info.value)
        assert "blocked.domain.tld" not in dns_lookups_made

        # The allowed request goes through the resolver.
        async with session.get(f"http://localhost:{server.port}") as resp:
            assert resp.status == 200

        assert len(dns_lookups_made) > 0
        # The blocked domain must still be absent after the allowed request.
        assert "blocked.domain.tld" not in dns_lookups_made

    # Clean up
    await connector.close()
async def test_middleware_uses_session_avoids_recursion_with_path_check(
    aiohttp_server: AiohttpServer,
) -> None:
    """Verify a middleware can reuse its own session if it skips ``/log``.

    The middleware reports every request to a log server through
    ``request.session``. That nested POST runs through the same middleware,
    so the middleware skips reporting when the path is ``/log``; otherwise
    it would recurse forever.
    """
    log_collector: list[dict[str, str]] = []

    async def log_api_handler(request: web.Request) -> web.Response:
        """Store the JSON payload posted to the log endpoint."""
        payload: dict[str, str] = await request.json()
        log_collector.append(payload)
        return web.Response(text="OK")

    async def main_handler(request: web.Request) -> web.Response:
        """Answer with the requested path."""
        return web.Response(text=f"Hello from {request.path}")

    # Log API server
    log_app = web.Application()
    log_app.router.add_post("/log", log_api_handler)
    log_server = await aiohttp_server(log_app)

    # Main server
    main_app = web.Application()
    main_app.router.add_get("/{path:.*}", main_handler)
    main_server = await aiohttp_server(main_app)

    async def log_middleware(
        request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        """Report the request to the log API unless it targets ``/log``."""
        # Requests to /log are not reported, which breaks the recursion.
        if request.url.path != "/log":
            report = {"method": str(request.method), "url": str(request.url)}
            async with request.session.post(
                f"http://localhost:{log_server.port}/log",
                json=report,
            ) as resp:
                assert resp.status == 200

        return await handler(request)

    async with ClientSession(middlewares=(log_middleware,)) as session:
        # Request to the main server: the middleware reports it.
        async with session.get(main_server.make_url("/test")) as resp:
            assert resp.status == 200
            text = await resp.text()
            assert text == "Hello from /test"

        # Direct POST to the log API: recorded by the handler itself, but the
        # middleware adds no report of its own for it.
        async with session.post(
            log_server.make_url("/log"),
            json={"method": "DIRECT_POST", "url": "manual_test_entry"},
        ) as resp:
            assert resp.status == 200

    # Exactly two entries: the middleware's report and the direct POST.
    expected_entries = [
        ("GET", str(main_server.make_url("/test"))),
        ("DIRECT_POST", "manual_test_entry"),
    ]
    assert len(log_collector) == 2
    for entry, (method, url) in zip(log_collector, expected_entries):
        assert entry["method"] == method
        assert entry["url"] == url
async def test_middleware_can_check_request_body(
    aiohttp_server: AiohttpServer,
) -> None:
    """Verify a middleware can read the outgoing body to build a header.

    The middleware derives a ``CUSTOM-AUTH`` header from the request body.
    The server records the body and headers of each request, and the test
    compares them with the values computed from what the client sent.
    """
    received_bodies: list[str] = []
    received_headers: list[dict[str, str]] = []

    async def handler(request: web.Request) -> web.Response:
        """Record the body and headers of every incoming request."""
        received_bodies.append(await request.text())
        received_headers.append(dict(request.headers))
        return web.Response(text="OK")

    app = web.Application()
    app.router.add_post("/api", handler)
    app.router.add_get("/api", handler)  # the GET case uses the same route
    server = await aiohttp_server(app)

    class CustomAuth:
        """Middleware that signs each request based on its body."""

        def __init__(self, secretkey: str) -> None:
            self.secretkey = secretkey

        def get_hash(self, request: ClientRequest) -> str:
            # An empty body is signed as if it were "{}".
            data = request.body.decode("utf-8") or "{}"
            # Stand-in signature; no real cryptography is involved.
            return f"SIGNATURE-{self.secretkey}-{len(data)}-{data[:10]}"

        async def __call__(
            self, request: ClientRequest, handler: ClientHandlerType
        ) -> ClientResponse:
            request.headers["CUSTOM-AUTH"] = self.get_hash(request)
            return await handler(request)

    api_url = server.make_url("/api")
    json_headers = {"Content-Type": "application/json"}

    async with ClientSession(middlewares=(CustomAuth("test-secret-key"),)) as session:
        # Request 1: JSON with user/action
        json_str1 = json.dumps({"user": "alice", "action": "login"})
        async with session.post(api_url, data=json_str1, headers=json_headers) as resp:
            assert resp.status == 200

        # Request 2: JSON with different fields
        json_str2 = json.dumps({"user": "bob", "value": 42})
        async with session.post(api_url, data=json_str2, headers=json_headers) as resp:
            assert resp.status == 200

        # Request 3: GET without a body
        async with session.get(api_url) as resp:
            assert resp.status == 200

        # Request 4: plain text (non-JSON)
        text_data = "plain text body"
        async with session.post(
            api_url,
            data=text_data,
            headers={"Content-Type": "text/plain"},
        ) as resp:
            assert resp.status == 200

    # The signature header must match each body the client sent.
    assert (
        received_headers[0]["CUSTOM-AUTH"]
        == f"SIGNATURE-test-secret-key-{len(json_str1)}-{json_str1[:10]}"
    )
    assert (
        received_headers[1]["CUSTOM-AUTH"]
        == f"SIGNATURE-test-secret-key-{len(json_str2)}-{json_str2[:10]}"
    )
    # The bodiless GET is signed with the "{}" placeholder.
    assert received_headers[2]["CUSTOM-AUTH"] == "SIGNATURE-test-secret-key-2-{}"
    assert (
        received_headers[3]["CUSTOM-AUTH"]
        == f"SIGNATURE-test-secret-key-{len(text_data)}-{text_data[:10]}"
    )

    # The server must have received the bodies unchanged.
    assert received_bodies[0] == json_str1
    assert received_bodies[1] == json_str2
    assert received_bodies[2] == ""  # the GET request has no body
    assert received_bodies[3] == text_data
async def test_client_middleware_update_string_body(
    aiohttp_server: AiohttpServer,
) -> None:
    """Verify ``update_body`` can swap one ``str`` body for another.

    The request is sent with a string payload. The middleware replaces it
    with a different string, and the echo server must return the replacement.
    """

    async def handler(request: web.Request) -> web.Response:
        # Echo the received body back to the client.
        return web.Response(text=await request.text())

    app = web.Application()
    app.router.add_post("/", handler)
    server = await aiohttp_server(app)

    async def update_body_middleware(
        request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        # Replace the payload before the request is sent.
        await request.update_body("this is a string")
        return await handler(request)

    async with ClientSession(middlewares=(update_body_middleware,)) as session:
        async with session.post(server.make_url("/"), data="original string") as resp:
            assert resp.status == 200
            echoed = await resp.text()
            assert echoed == "this is a string"
@pytest.fixture
def complete_challenge() -> DigestAuthChallenge:
    """Provide a digest challenge with every field set.

    Sets ``realm``, ``nonce``, ``qop``, ``algorithm`` and ``opaque``.
    """
    # Start from the two required fields, then add the optional ones.
    challenge = DigestAuthChallenge(realm="test", nonce="abc")
    challenge["qop"] = "auth"
    challenge["algorithm"] = "MD5"
    challenge["opaque"] = "xyz"
    return challenge
@pytest.fixture
def mock_sha1_digest() -> Generator[mock.MagicMock, None, None]:
    """Patch ``hashlib.sha1`` so it returns a hash object with a fixed hexdigest.

    Yields the mock that replaces ``hashlib.sha1``. The patch is removed when
    the fixture is torn down.
    """
    fake_hash = mock.MagicMock(spec=sha1())
    fake_hash.hexdigest.return_value = "deadbeefcafebabe"

    patcher = mock.patch("hashlib.sha1", return_value=fake_hash)
    patched = patcher.start()
    try:
        yield patched
    finally:
        # Always restore the real hashlib.sha1, even if the test failed.
        patcher.stop()
@pytest.mark.parametrize(
    ("challenge", "expected_error"),
    [
        (
            DigestAuthChallenge(),
            "Malformed Digest auth challenge: Missing 'realm' parameter",
        ),
        (
            DigestAuthChallenge(nonce="abc"),
            "Malformed Digest auth challenge: Missing 'realm' parameter",
        ),
        (
            DigestAuthChallenge(realm="test"),
            "Malformed Digest auth challenge: Missing 'nonce' parameter",
        ),
        (
            DigestAuthChallenge(realm="test", nonce=""),
            "Security issue: Digest auth challenge contains empty 'nonce' value",
        ),
    ],
)
async def test_encode_validation_errors(
    digest_auth_mw: DigestAuthMiddleware,
    challenge: DigestAuthChallenge,
    expected_error: str,
) -> None:
    """Verify ``_encode`` rejects incomplete or unsafe challenges.

    Each case installs a challenge that lacks a required field or carries an
    empty nonce, and expects a ``ClientError`` matching ``expected_error``.
    """
    target = URL("http://example.com/resource")
    digest_auth_mw._challenge = challenge

    with pytest.raises(ClientError, match=expected_error):
        await digest_auth_mw._encode("GET", target, b"")
async def test_encode_unsupported_algorithm(
    digest_auth_mw: DigestAuthMiddleware, basic_challenge: DigestAuthChallenge
) -> None:
    """Verify ``_encode`` raises ``ClientError`` for an unknown algorithm."""
    # Copy the fixture's challenge and override the algorithm on the copy.
    challenge: DigestAuthChallenge = {**basic_challenge, "algorithm": "UNSUPPORTED"}
    digest_auth_mw._challenge = challenge

    target = URL("http://example.com/resource")
    with pytest.raises(ClientError, match="Unsupported hash algorithm"):
        await digest_auth_mw._encode("GET", target, b"")
def compute_expected_digest(
    algorithm: str,
    username: str,
    password: str,
    realm: str,
    nonce: str,
    uri: str,
    method: str,
    qop: str,
    nc: str,
    cnonce: str,
    body: str = "",
) -> str:
    """Compute the reference ``response`` value for a Digest ``Authorization`` header.

    This is an independent implementation used by the tests to check the
    middleware's output.

    Args:
        algorithm: Name used to pick the hash function from ``DigestFunctions``.
            The name is tried as given and then upper-cased, so ``"md5"`` and
            ``"MD5"`` resolve to the same function. A name ending in ``-SESS``
            (any case) selects the session variant of HA1.
        username: User name that goes into A1.
        password: Password that goes into A1.
        realm: Realm that goes into A1.
        nonce: Server nonce.
        uri: Request URI that goes into A2.
        method: HTTP method that goes into A2.
        qop: Quality-of-protection value. An empty string selects the form
            without ``nc``/``cnonce``/``qop``. If it contains ``"auth-int"``,
            the hash of ``body`` is appended to A2.
        nc: Nonce count as it appears in the header. Used only when ``qop``
            is non-empty.
        cnonce: Client nonce. Used for ``-SESS`` algorithms and when ``qop``
            is non-empty.
        body: Request body as text. Hashed only for ``auth-int``.

    Returns:
        The hex digest expected in the header's ``response`` field.

    Raises:
        KeyError: If ``algorithm`` is not found in ``DigestFunctions`` either
            as given or upper-cased.
    """
    # Try the exact name first so previously working names resolve as before.
    # Fall back to upper case: the ``-SESS`` check below is already
    # case-insensitive, and lower/mixed-case names would otherwise raise.
    hash_fn = DigestFunctions.get(algorithm) or DigestFunctions[algorithm.upper()]

    def H(x: str) -> str:
        return hash_fn(x.encode()).hexdigest()

    def KD(secret: str, data: str) -> str:
        return H(f"{secret}:{data}")

    HA1 = H(f"{username}:{realm}:{password}")
    if algorithm.upper().endswith("-SESS"):
        # Session variants bind HA1 to the nonce pair.
        HA1 = H(f"{HA1}:{nonce}:{cnonce}")

    A2 = f"{method}:{uri}"
    if "auth-int" in qop:
        # Integrity protection adds the hash of the entity body.
        A2 = f"{A2}:{H(body)}"
    HA2 = H(A2)

    if qop:
        return KD(HA1, f"{nonce}:{nc}:{cnonce}:{qop}:{HA2}")
    return KD(HA1, f"{nonce}:{HA2}")
"example.com" nonce = "abc123nonce" cnonce = "deadbeefcafebabe" nc = 1 ncvalue = f"{nc+1:08x}" method = "GET" uri = "/secret" qop = "auth-int" if "auth-int" in qop else "auth" # Create the auth object auth = DigestAuthMiddleware(login, password) auth._challenge = DigestAuthChallenge( realm=realm, nonce=nonce, qop=qop, algorithm=algorithm ) auth._last_nonce_bytes = nonce.encode("utf-8") auth._nonce_count = nc header = await auth._encode(method, URL(f"http://host{uri}"), body) # Get expected digest expected = compute_expected_digest( algorithm=algorithm, username=login, password=password, realm=realm, nonce=nonce, uri=uri, method=method, qop=qop, nc=ncvalue, cnonce=cnonce, body=body_str, ) # Check that the response digest is exactly correct assert f'response="{expected}"' in header @pytest.mark.parametrize( ("header", "expected_result"), [ # Normal quoted values ( 'realm="example.com", nonce="12345", qop="auth"', {"realm": "example.com", "nonce": "12345", "qop": "auth"}, ), # Unquoted values ( "realm=example.com, nonce=12345, qop=auth", {"realm": "example.com", "nonce": "12345", "qop": "auth"}, ), # Mixed quoted/unquoted with commas in quoted values ( 'realm="ex,ample", nonce=12345, qop="auth", domain="/test"', { "realm": "ex,ample", "nonce": "12345", "qop": "auth", "domain": "/test", }, ), # Header with scheme ( 'Digest realm="example.com", nonce="12345", qop="auth"', {"realm": "example.com", "nonce": "12345", "qop": "auth"}, ), # No spaces after commas ( 'realm="test",nonce="123",qop="auth"', {"realm": "test", "nonce": "123", "qop": "auth"}, ), # Extra whitespace ( 'realm = "test" , nonce = "123"', {"realm": "test", "nonce": "123"}, ), # Escaped quotes ( 'realm="test\\"realm", nonce="123"', {"realm": 'test"realm', "nonce": "123"}, ), # Single quotes (treated as regular chars) ( "realm='test', nonce=123", {"realm": "'test'", "nonce": "123"}, ), # Empty header ("", {}), ], ids=[ "fully_quoted_header", "unquoted_header", "mixed_quoted_unquoted_with_commas", 
"header_with_scheme", "no_spaces_after_commas", "extra_whitespace", "escaped_quotes", "single_quotes_as_regular_chars", "empty_header", ], ) def test_parse_header_pairs(header: str, expected_result: dict[str, str]) -> None: """Test parsing HTTP header pairs with various formats.""" result = parse_header_pairs(header) assert result == expected_result def test_digest_auth_middleware_callable(digest_auth_mw: DigestAuthMiddleware) -> None: """Test that DigestAuthMiddleware is callable.""" assert callable(digest_auth_mw) def test_middleware_invalid_login() -> None: """Test that invalid login values raise errors.""" with pytest.raises(ValueError, match="None is not allowed as login value"): DigestAuthMiddleware(None, "pass") # type: ignore[arg-type] with pytest.raises(ValueError, match="None is not allowed as password value"): DigestAuthMiddleware("user", None) # type: ignore[arg-type] with pytest.raises(ValueError, match=r"A \":\" is not allowed in username"): DigestAuthMiddleware("user:name", "pass") async def test_escaping_quotes_in_auth_header() -> None: """Test that double quotes are properly escaped in auth header.""" auth = DigestAuthMiddleware('user"with"quotes', "pass") auth._challenge = DigestAuthChallenge( realm='realm"with"quotes', nonce='nonce"with"quotes', qop="auth", algorithm="MD5", opaque='opaque"with"quotes', ) header = await auth._encode("GET", URL("http://example.com/path"), b"") # Check that quotes are escaped in the header assert 'username="user\\"with\\"quotes"' in header assert 'realm="realm\\"with\\"quotes"' in header assert 'nonce="nonce\\"with\\"quotes"' in header assert 'opaque="opaque\\"with\\"quotes"' in header async def test_template_based_header_construction( auth_mw_with_challenge: DigestAuthMiddleware, mock_sha1_digest: mock.MagicMock, mock_md5_digest: mock.MagicMock, ) -> None: """Test that the template-based header construction works correctly.""" header = await auth_mw_with_challenge._encode( "GET", URL("http://example.com/test"), b"" 
@pytest.mark.parametrize(
    ("test_string", "expected_escaped", "description"),
    [
        ('value"with"quotes', 'value\\"with\\"quotes', "Basic string with quotes"),
        ("", "", "Empty string"),
        ("no quotes", "no quotes", "String without quotes"),
        ('with"one"quote', 'with\\"one\\"quote', "String with one quoted segment"),
        (
            'many"quotes"in"string',
            'many\\"quotes\\"in\\"string',
            "String with multiple quoted segments",
        ),
        ('""', '\\"\\"', "Just double quotes"),
        ('"', '\\"', "Single double quote"),
        ('already\\"escaped', 'already\\\\"escaped', "Already escaped quotes"),
    ],
)
def test_quote_escaping_functions(
    test_string: str, expected_escaped: str, description: str
) -> None:
    """Verify ``escape_quotes`` output and that ``unescape_quotes`` reverses it.

    ``description`` only labels the case in the parameter list; no assertion
    reads it.
    """
    # Forward direction: the escaped form matches the expectation.
    escaped_form = escape_quotes(test_string)
    assert escaped_form == expected_escaped

    # Reverse direction: unescaping restores the input.
    assert unescape_quotes(escaped_form) == test_string

    # Round trip through both functions in one expression.
    round_tripped = unescape_quotes(escape_quotes(test_string))
    assert round_tripped == test_string
request_count request_count += 1 if request_count == 1: # First request returns 401 with digest challenge without qop challenge = ( f'Digest realm="{realm}", nonce="{nonce}", algorithm={algorithm}' ) return Response( status=401, headers={"WWW-Authenticate": challenge}, text="Unauthorized", ) # Second request should have Authorization header auth_header = request.headers.get(hdrs.AUTHORIZATION) assert auth_header and auth_header.startswith("Digest ") # Successful auth should have no qop param assert "qop=" not in auth_header assert "nc=" not in auth_header assert "cnonce=" not in auth_header expected_digest = compute_expected_digest( algorithm=algorithm, username=username, password=password, realm=realm, nonce=nonce, uri=uri, method="GET", qop="", # This is the key part - explicitly setting qop="" nc="", # Not needed for non-qop digest cnonce="", # Not needed for non-qop digest ) # We mock the cnonce, so we can check the expected digest assert expected_digest in auth_header return Response(text="OK") app = Application() app.router.add_get("/", handler) server = await aiohttp_server(app) async with ClientSession(middlewares=(digest_auth_mw,)) as session: async with session.get(server.make_url("/")) as resp: assert resp.status == 200 text_content = await resp.text() assert text_content == "OK" assert request_count == 2 # Initial request + retry with auth async def test_digest_auth_without_opaque( aiohttp_server: AiohttpServer, digest_auth_mw: DigestAuthMiddleware ) -> None: """Test digest auth with a server that doesn't provide an opaque parameter.""" request_count = 0 async def handler(request: Request) -> Response: nonlocal request_count request_count += 1 if request_count == 1: # First request returns 401 with digest challenge without opaque challenge = ( 'Digest realm="test-realm", nonce="testnonce", ' 'qop="auth", algorithm=MD5' ) return Response( status=401, headers={"WWW-Authenticate": challenge}, text="Unauthorized", ) # Second request should have 
Authorization header auth_header = request.headers.get(hdrs.AUTHORIZATION) assert auth_header and auth_header.startswith("Digest ") # Successful auth should have no opaque param assert "opaque=" not in auth_header return Response(text="OK") app = Application() app.router.add_get("/", handler) server = await aiohttp_server(app) async with ClientSession(middlewares=(digest_auth_mw,)) as session: async with session.get(server.make_url("/")) as resp: assert resp.status == 200 text_content = await resp.text() assert text_content == "OK" assert request_count == 2 # Initial request + retry with auth @pytest.mark.parametrize( "www_authenticate", [ None, "DigestWithoutSpace", 'Basic realm="test"', "Digest ", "Digest =invalid, format", ], ) async def test_auth_header_no_retry( aiohttp_server: AiohttpServer, www_authenticate: str, digest_auth_mw: DigestAuthMiddleware, ) -> None: """Test that middleware doesn't retry with invalid WWW-Authenticate headers.""" request_count = 0 async def handler(request: Request) -> Response: nonlocal request_count request_count += 1 # First (and only) request returns 401 headers = {} if www_authenticate is not None: headers["WWW-Authenticate"] = www_authenticate # Use a custom HTTPUnauthorized instead of the helper since # we're specifically testing malformed headers return Response(status=401, headers=headers, text="Unauthorized") app = Application() app.router.add_get("/", handler) server = await aiohttp_server(app) async with ClientSession(middlewares=(digest_auth_mw,)) as session: async with session.get(server.make_url("/")) as resp: assert resp.status == 401 # No retry should happen assert request_count == 1 async def test_direct_success_no_auth_needed( aiohttp_server: AiohttpServer, digest_auth_mw: DigestAuthMiddleware ) -> None: """Test middleware with a direct 200 response with no auth challenge.""" request_count = 0 async def handler(request: Request) -> Response: nonlocal request_count request_count += 1 # Return success without auth 
challenge return Response(text="OK") app = Application() app.router.add_get("/", handler) server = await aiohttp_server(app) async with ClientSession(middlewares=(digest_auth_mw,)) as session: async with session.get(server.make_url("/")) as resp: text = await resp.text() assert resp.status == 200 assert text == "OK" # Verify only one request was made assert request_count == 1 async def test_no_retry_on_second_401( aiohttp_server: AiohttpServer, digest_auth_mw: DigestAuthMiddleware ) -> None: """Test digest auth does not retry on second 401.""" request_count = 0 async def handler(request: Request) -> Response: nonlocal request_count request_count += 1 # Always return 401 challenge challenge = 'Digest realm="test", nonce="abc123", qop="auth", algorithm=MD5' return Response( status=401, headers={"WWW-Authenticate": challenge}, text="Unauthorized", ) app = Application() app.router.add_get("/", handler) server = await aiohttp_server(app) # Create a session that uses the digest auth middleware async with ClientSession(middlewares=(digest_auth_mw,)) as session: async with session.get(server.make_url("/")) as resp: await resp.text() assert resp.status == 401 # Verify we made exactly 2 requests (initial + 1 retry) assert request_count == 2 async def test_preemptive_auth_disabled( aiohttp_server: AiohttpServer, ) -> None: """Test that preemptive authentication can be disabled.""" digest_auth_mw = DigestAuthMiddleware("user", "pass", preemptive=False) request_count = 0 auth_headers = [] async def handler(request: Request) -> Response: nonlocal request_count request_count += 1 auth_headers.append(request.headers.get(hdrs.AUTHORIZATION)) if not request.headers.get(hdrs.AUTHORIZATION): # Return 401 with digest challenge challenge = 'Digest realm="test", nonce="abc123", qop="auth", algorithm=MD5' return Response( status=401, headers={"WWW-Authenticate": challenge}, text="Unauthorized", ) return Response(text="OK") app = Application() app.router.add_get("/", handler) server = 
await aiohttp_server(app) async with ClientSession(middlewares=(digest_auth_mw,)) as session: # First request will get 401 and store challenge async with session.get(server.make_url("/")) as resp: assert resp.status == 200 text = await resp.text() assert text == "OK" # Second request should NOT send auth preemptively (preemptive=False) async with session.get(server.make_url("/")) as resp: assert resp.status == 200 text = await resp.text() assert text == "OK" # With preemptive disabled, each request needs 401 challenge first assert request_count == 4 # 2 requests * 2 (401 + retry) assert auth_headers[0] is None # First request has no auth assert auth_headers[1] is not None # Second request has auth after 401 assert auth_headers[2] is None # Third request has no auth (preemptive disabled) assert auth_headers[3] is not None # Fourth request has auth after 401 async def test_preemptive_auth_with_stale_nonce( aiohttp_server: AiohttpServer, ) -> None: """Test preemptive auth handles stale nonce responses correctly.""" digest_auth_mw = DigestAuthMiddleware("user", "pass", preemptive=True) request_count = 0 current_nonce = 0 async def handler(request: Request) -> Response: nonlocal request_count, current_nonce request_count += 1 auth_header = request.headers.get(hdrs.AUTHORIZATION) if not auth_header: # First request without auth current_nonce = 1 challenge = f'Digest realm="test", nonce="nonce{current_nonce}", qop="auth", algorithm=MD5' return Response( status=401, headers={"WWW-Authenticate": challenge}, text="Unauthorized", ) # For the second set of requests, always consider the first nonce stale if request_count == 3 and current_nonce == 1: # Stale nonce - request new auth with stale=true current_nonce = 2 challenge = f'Digest realm="test", nonce="nonce{current_nonce}", qop="auth", algorithm=MD5, stale=true' return Response( status=401, headers={"WWW-Authenticate": challenge}, text="Unauthorized - Stale nonce", ) return Response(text="OK") app = Application() 
app.router.add_get("/", handler) server = await aiohttp_server(app) async with ClientSession(middlewares=(digest_auth_mw,)) as session: # First request - will get 401, then retry with auth async with session.get(server.make_url("/")) as resp: assert resp.status == 200 text = await resp.text() assert text == "OK" # Second request - will use preemptive auth with nonce1, get 401 stale, retry with nonce2 async with session.get(server.make_url("/")) as resp: assert resp.status == 200 text = await resp.text() assert text == "OK" # Verify the expected flow: # Request 1: no auth -> 401 # Request 2: retry with auth -> 200 # Request 3: preemptive auth with old nonce -> 401 stale # Request 4: retry with new nonce -> 200 assert request_count == 4 async def test_preemptive_auth_updates_nonce_count( aiohttp_server: AiohttpServer, ) -> None: """Test that preemptive auth properly increments nonce count.""" digest_auth_mw = DigestAuthMiddleware("user", "pass", preemptive=True) request_count = 0 nonce_counts = [] async def handler(request: Request) -> Response: nonlocal request_count request_count += 1 auth_header = request.headers.get(hdrs.AUTHORIZATION) if not auth_header: # First request without auth challenge = 'Digest realm="test", nonce="abc123", qop="auth", algorithm=MD5' return Response( status=401, headers={"WWW-Authenticate": challenge}, text="Unauthorized", ) # Extract nc (nonce count) from auth header nc_match = auth_header.split("nc=")[1].split(",")[0].strip() nonce_counts.append(nc_match) return Response(text="OK") app = Application() app.router.add_get("/", handler) server = await aiohttp_server(app) async with ClientSession(middlewares=(digest_auth_mw,)) as session: # Make multiple requests to see nonce count increment for _ in range(3): async with session.get(server.make_url("/")) as resp: assert resp.status == 200 await resp.text() # First request has no auth, then gets 401 and retries with nc=00000001 # Second and third requests use preemptive auth with 
nc=00000002 and nc=00000003 assert len(nonce_counts) == 3 assert nonce_counts[0] == "00000001" assert nonce_counts[1] == "00000002" assert nonce_counts[2] == "00000003" async def test_preemptive_auth_respects_protection_space( aiohttp_server: AiohttpServer, ) -> None: """Test that preemptive auth only applies to URLs within the protection space.""" digest_auth_mw = DigestAuthMiddleware("user", "pass", preemptive=True) request_count = 0 auth_headers = [] requested_paths = [] async def handler(request: Request) -> Response: nonlocal request_count request_count += 1 auth_headers.append(request.headers.get(hdrs.AUTHORIZATION)) requested_paths.append(request.path) if not request.headers.get(hdrs.AUTHORIZATION): # Return 401 with digest challenge including domain parameter challenge = 'Digest realm="test", nonce="abc123", qop="auth", algorithm=MD5, domain="/api /admin"' return Response( status=401, headers={"WWW-Authenticate": challenge}, text="Unauthorized", ) return Response(text="OK") app = Application() app.router.add_get("/api/endpoint", handler) app.router.add_get("/admin/panel", handler) app.router.add_get("/public/page", handler) server = await aiohttp_server(app) async with ClientSession(middlewares=(digest_auth_mw,)) as session: # First request to /api/endpoint - should get 401 and retry with auth async with session.get(server.make_url("/api/endpoint")) as resp: assert resp.status == 200 # Second request to /api/endpoint - should use preemptive auth (in protection space) async with session.get(server.make_url("/api/endpoint")) as resp: assert resp.status == 200 # Third request to /admin/panel - should use preemptive auth (in protection space) async with session.get(server.make_url("/admin/panel")) as resp: assert resp.status == 200 # Fourth request to /public/page - should NOT use preemptive auth (outside protection space) async with session.get(server.make_url("/public/page")) as resp: assert resp.status == 200 # Verify auth headers assert auth_headers[0] is 
None # First request to /api/endpoint - no auth assert auth_headers[1] is not None # Retry with auth assert ( auth_headers[2] is not None ) # Second request to /api/endpoint - preemptive auth assert auth_headers[3] is not None # Request to /admin/panel - preemptive auth assert auth_headers[4] is None # First request to /public/page - no preemptive auth assert auth_headers[5] is not None # Retry with auth # Verify paths assert requested_paths == [ "/api/endpoint", # Initial request "/api/endpoint", # Retry with auth "/api/endpoint", # Second request with preemptive auth "/admin/panel", # Request with preemptive auth "/public/page", # Initial request (no preemptive auth) "/public/page", # Retry with auth ] async def test_preemptive_auth_with_absolute_domain_uris( aiohttp_server: AiohttpServer, ) -> None: """Test preemptive auth with absolute URIs in domain parameter.""" digest_auth_mw = DigestAuthMiddleware("user", "pass", preemptive=True) request_count = 0 auth_headers = [] async def handler(request: Request) -> Response: nonlocal request_count request_count += 1 auth_headers.append(request.headers.get(hdrs.AUTHORIZATION)) if not request.headers.get(hdrs.AUTHORIZATION): # Return 401 with digest challenge including absolute URI in domain server_url = str(request.url.with_path("/protected")) challenge = f'Digest realm="test", nonce="abc123", qop="auth", algorithm=MD5, domain="{server_url}"' return Response( status=401, headers={"WWW-Authenticate": challenge}, text="Unauthorized", ) return Response(text="OK") app = Application() app.router.add_get("/protected/resource", handler) app.router.add_get("/unprotected/resource", handler) server = await aiohttp_server(app) async with ClientSession(middlewares=(digest_auth_mw,)) as session: # First request to protected resource async with session.get(server.make_url("/protected/resource")) as resp: assert resp.status == 200 # Second request to protected resource - should use preemptive auth async with 
session.get(server.make_url("/protected/resource")) as resp: assert resp.status == 200 # Request to unprotected resource - should NOT use preemptive auth async with session.get(server.make_url("/unprotected/resource")) as resp: assert resp.status == 200 # Verify auth pattern assert auth_headers[0] is None # First request - no auth assert auth_headers[1] is not None # Retry with auth assert auth_headers[2] is not None # Second request - preemptive auth assert auth_headers[3] is None # Unprotected resource - no preemptive auth assert auth_headers[4] is not None # Retry with auth async def test_preemptive_auth_without_domain_uses_origin( aiohttp_server: AiohttpServer, ) -> None: """Test that preemptive auth without domain parameter applies to entire origin.""" digest_auth_mw = DigestAuthMiddleware("user", "pass", preemptive=True) request_count = 0 auth_headers = [] async def handler(request: Request) -> Response: nonlocal request_count request_count += 1 auth_headers.append(request.headers.get(hdrs.AUTHORIZATION)) if not request.headers.get(hdrs.AUTHORIZATION): # Return 401 with digest challenge without domain parameter challenge = 'Digest realm="test", nonce="abc123", qop="auth", algorithm=MD5' return Response( status=401, headers={"WWW-Authenticate": challenge}, text="Unauthorized", ) return Response(text="OK") app = Application() app.router.add_get("/path1", handler) app.router.add_get("/path2", handler) server = await aiohttp_server(app) async with ClientSession(middlewares=(digest_auth_mw,)) as session: # First request async with session.get(server.make_url("/path1")) as resp: assert resp.status == 200 # Second request to different path - should still use preemptive auth async with session.get(server.make_url("/path2")) as resp: assert resp.status == 200 # Verify auth pattern assert auth_headers[0] is None # First request - no auth assert auth_headers[1] is not None # Retry with auth assert ( auth_headers[2] is not None ) # Second request - preemptive auth 
(entire origin) @pytest.mark.parametrize( ("status", "headers", "expected"), [ (200, {}, False), (401, {"www-authenticate": ""}, False), (401, {"www-authenticate": "DigestWithoutSpace"}, False), (401, {"www-authenticate": "Basic realm=test"}, False), (401, {"www-authenticate": "Digest "}, False), (401, {"www-authenticate": "Digest =invalid, format"}, False), ], ids=[ "different_status_code", "empty_www_authenticate_header", "no_space_after_scheme", "different_scheme", "empty_parameters", "malformed_parameters", ], ) def test_authenticate_with_malformed_headers( digest_auth_mw: DigestAuthMiddleware, status: int, headers: dict[str, str], expected: bool, ) -> None: """Test _authenticate method with various edge cases.""" response = mock.MagicMock(spec=ClientResponse) response.status = status response.headers = headers result = digest_auth_mw._authenticate(response) assert result == expected @pytest.mark.parametrize( ("protection_space_url", "request_url", "expected"), [ # Exact match ("http://example.com/app1", "http://example.com/app1", True), # Path with trailing slash should match ("http://example.com/app1", "http://example.com/app1/", True), # Subpaths should match ("http://example.com/app1", "http://example.com/app1/resource", True), ("http://example.com/app1", "http://example.com/app1/sub/path", True), # Should NOT match different paths that start with same prefix ("http://example.com/app1", "http://example.com/app1xx", False), ("http://example.com/app1", "http://example.com/app123", False), # Protection space with trailing slash ("http://example.com/app1/", "http://example.com/app1/", True), ("http://example.com/app1/", "http://example.com/app1/resource", True), ( "http://example.com/app1/", "http://example.com/app1", False, ), # No trailing slash # Root protection space ("http://example.com/", "http://example.com/", True), ("http://example.com/", "http://example.com/anything", True), ("http://example.com/", "http://example.com", False), # No trailing slash # 
Different origins should not match ("http://example.com/app1", "https://example.com/app1", False), ("http://example.com/app1", "http://other.com/app1", False), ("http://example.com:8080/app1", "http://example.com/app1", False), ], ids=[ "exact_match", "path_with_trailing_slash", "subpath_match", "deep_subpath_match", "no_match_app1xx", "no_match_app123", "protection_with_slash_exact", "protection_with_slash_subpath", "protection_with_slash_no_match_without", "root_protection_exact", "root_protection_subpath", "root_protection_no_match_without_slash", "different_scheme", "different_host", "different_port", ], ) def test_in_protection_space( digest_auth_mw: DigestAuthMiddleware, protection_space_url: str, request_url: str, expected: bool, ) -> None: """Test _in_protection_space method with various URL patterns.""" digest_auth_mw._protection_space = [protection_space_url] result = digest_auth_mw._in_protection_space(URL(request_url)) assert result == expected def test_in_protection_space_multiple_spaces( digest_auth_mw: DigestAuthMiddleware, ) -> None: """Test _in_protection_space with multiple protection spaces.""" digest_auth_mw._protection_space = [ "http://example.com/api", "http://example.com/admin/", "http://example.com/secure/area", ] # Test various URLs assert digest_auth_mw._in_protection_space(URL("http://example.com/api")) is True assert digest_auth_mw._in_protection_space(URL("http://example.com/api/v1")) is True assert ( digest_auth_mw._in_protection_space(URL("http://example.com/admin/panel")) is True ) assert ( digest_auth_mw._in_protection_space( URL("http://example.com/secure/area/resource") ) is True ) # These should not match assert digest_auth_mw._in_protection_space(URL("http://example.com/apiv2")) is False assert ( digest_auth_mw._in_protection_space(URL("http://example.com/admin")) is False ) # No trailing slash assert ( digest_auth_mw._in_protection_space(URL("http://example.com/secure")) is False ) assert 
digest_auth_mw._in_protection_space(URL("http://example.com/other")) is False async def test_case_sensitive_algorithm_server( aiohttp_server: AiohttpServer, ) -> None: """Test authentication with a server that requires exact algorithm case matching. This simulates servers like Prusa printers that expect the algorithm to be returned with the exact same case as sent in the challenge. """ digest_auth_mw = DigestAuthMiddleware("testuser", "testpass") request_count = 0 auth_algorithms: list[str] = [] async def handler(request: Request) -> Response: nonlocal request_count request_count += 1 if not (auth_header := request.headers.get(hdrs.AUTHORIZATION)): # Send challenge with lowercase-sess algorithm (like Prusa) challenge = 'Digest realm="Administrator", nonce="test123", qop="auth", algorithm="MD5-sess", opaque="xyz123"' return Response( status=401, headers={"WWW-Authenticate": challenge}, text="Unauthorized", ) # Extract algorithm from auth response algo_match = re.search(r"algorithm=([^,\s]+)", auth_header) assert algo_match is not None auth_algorithms.append(algo_match.group(1)) # Case-sensitive server: only accept exact case match assert "algorithm=MD5-sess" in auth_header return Response(text="Success") app = Application() app.router.add_get("/api/test", handler) server = await aiohttp_server(app) async with ( ClientSession(middlewares=(digest_auth_mw,)) as session, session.get(server.make_url("/api/test")) as resp, ): assert resp.status == 200 text = await resp.text() assert text == "Success" # Verify the middleware preserved the exact algorithm case assert request_count == 2 # Initial 401 + successful retry assert len(auth_algorithms) == 1 assert auth_algorithms[0] == "MD5-sess" # Not "MD5-SESS" def test_regex_performance() -> None: """Test that the regex pattern doesn't suffer from ReDoS issues.""" REGEX_TIME_THRESHOLD_SECONDS = 0.08 value = "0" * 54773 + "\\0=a" start = time.perf_counter() matches = _HEADER_PAIRS_PATTERN.findall(value) elapsed = 
time.perf_counter() - start # If this is taking more time, there's probably a performance/ReDoS issue. assert elapsed < REGEX_TIME_THRESHOLD_SECONDS, ( f"Regex took {elapsed * 1000:.1f}ms, " f"expected <{REGEX_TIME_THRESHOLD_SECONDS * 1000:.0f}ms - potential ReDoS issue" ) # This example shouldn't produce a match either. assert not matches ================================================ FILE: tests/test_client_proto.py ================================================ import asyncio from unittest import mock from multidict import CIMultiDict from pytest_mock import MockerFixture from yarl import URL from aiohttp import http from aiohttp.client_exceptions import ClientOSError, ServerDisconnectedError from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientResponse from aiohttp.helpers import TimerNoop from aiohttp.http_parser import RawResponseMessage async def test_force_close(loop: asyncio.AbstractEventLoop) -> None: """Ensure that the force_close method sets the should_close attribute to True. 
This is used externally in aiodocker https://github.com/aio-libs/aiodocker/issues/920 """ proto = ResponseHandler(loop=loop) proto.force_close() assert proto.should_close async def test_oserror(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) transport = mock.Mock() proto.connection_made(transport) proto.connection_lost(OSError()) assert proto.should_close assert isinstance(proto.exception(), ClientOSError) async def test_pause_resume_on_error(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) transport = mock.Mock() proto.connection_made(transport) proto.pause_reading() assert proto._reading_paused proto.resume_reading() assert not proto._reading_paused async def test_client_proto_bad_message(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) transport = mock.Mock() proto.connection_made(transport) proto.set_response_params() proto.data_received(b"HTTP\r\n\r\n") assert proto.should_close assert transport.close.called assert isinstance(proto.exception(), http.HttpProcessingError) async def test_uncompleted_message(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) transport = mock.Mock() proto.connection_made(transport) proto.set_response_params(read_until_eof=True) proto.data_received( b"HTTP/1.1 301 Moved Permanently\r\nLocation: http://python.org/" ) proto.connection_lost(None) exc = proto.exception() assert isinstance(exc, ServerDisconnectedError) assert isinstance(exc.message, RawResponseMessage) assert exc.message.code == 301 assert dict(exc.message.headers) == {"Location": "http://python.org/"} async def test_data_received_after_close(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) transport = mock.Mock() proto.connection_made(transport) proto.set_response_params(read_until_eof=True) proto.close() assert transport.close.called transport.close.reset_mock() proto.data_received(b"HTTP\r\n\r\n") assert proto.should_close assert 
not transport.close.called assert isinstance(proto.exception(), http.HttpProcessingError) async def test_multiple_responses_one_byte_at_a_time( loop: asyncio.AbstractEventLoop, ) -> None: proto = ResponseHandler(loop=loop) proto.connection_made(mock.Mock()) conn = mock.Mock(protocol=proto) proto.set_response_params(read_until_eof=True) for _ in range(2): messages = ( b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nab" b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\ncd" b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nef" ) for i in range(len(messages)): proto.data_received(messages[i : i + 1]) expected = [b"ab", b"cd", b"ef"] url = URL("http://def-cl-resp.org") for payload in expected: response = ClientResponse( "get", url, writer=mock.Mock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) await response.start(conn) await response.read() == payload async def test_unexpected_exception_during_data_received( loop: asyncio.AbstractEventLoop, ) -> None: proto = ResponseHandler(loop=loop) class PatchableHttpResponseParser(http.HttpResponseParser): """Subclass of HttpResponseParser to make it patchable.""" with mock.patch( "aiohttp.client_proto.HttpResponseParser", PatchableHttpResponseParser ): proto.connection_made(mock.Mock()) conn = mock.Mock(protocol=proto) proto.set_response_params(read_until_eof=True) proto.data_received(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nab") url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=mock.Mock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) await response.start(conn) await response.read() == b"ab" with mock.patch.object(proto._parser, "feed_data", side_effect=ValueError): proto.data_received(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\ncd") assert isinstance(proto.exception(), http.HttpProcessingError) async def 
test_client_protocol_readuntil_eof(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) transport = mock.Mock() proto.connection_made(transport) conn = mock.Mock() conn.protocol = proto proto.data_received(b"HTTP/1.1 200 Ok\r\n\r\n") url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=mock.Mock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) proto.set_response_params(read_until_eof=True) await response.start(conn) assert not response.content.is_eof() proto.data_received(b"0000") data = await response.content.readany() assert data == b"0000" proto.data_received(b"1111") data = await response.content.readany() assert data == b"1111" proto.connection_lost(None) assert response.content.is_eof() async def test_empty_data(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) proto.data_received(b"") # do nothing async def test_schedule_timeout(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) proto.set_response_params(read_timeout=1) assert proto._read_timeout_handle is None proto.start_timeout() assert proto._read_timeout_handle is not None async def test_drop_timeout(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) proto.set_response_params(read_timeout=1) proto.start_timeout() assert proto._read_timeout_handle is not None proto._drop_timeout() assert proto._read_timeout_handle is None async def test_reschedule_timeout(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) proto.set_response_params(read_timeout=1) proto.start_timeout() assert proto._read_timeout_handle is not None h = proto._read_timeout_handle proto._reschedule_timeout() assert proto._read_timeout_handle is not None assert proto._read_timeout_handle is not h async def test_eof_received(loop: asyncio.AbstractEventLoop) -> None: proto = ResponseHandler(loop=loop) 
proto.set_response_params(read_timeout=1) proto.start_timeout() assert proto._read_timeout_handle is not None proto.eof_received() assert proto._read_timeout_handle is None async def test_connection_lost_sets_transport_to_none( loop: asyncio.AbstractEventLoop, mocker: MockerFixture ) -> None: """Ensure that the transport is set to None when the connection is lost. This ensures the writer knows that the connection is closed. """ proto = ResponseHandler(loop=loop) proto.connection_made(mocker.Mock()) assert proto.transport is not None proto.connection_lost(OSError()) assert proto.transport is None async def test_connection_lost_exception_is_marked_retrieved( loop: asyncio.AbstractEventLoop, ) -> None: """Test that connection_lost properly handles exceptions without warnings.""" proto = ResponseHandler(loop=loop) proto.connection_made(mock.Mock()) # Access closed property before connection_lost to ensure future is created closed_future = proto.closed assert closed_future is not None # Simulate an SSL shutdown timeout error ssl_error = TimeoutError("SSL shutdown timed out") proto.connection_lost(ssl_error) # Verify the exception was set on the closed future assert closed_future.done() exc = closed_future.exception() assert exc is not None assert "Connection lost: SSL shutdown timed out" in str(exc) assert exc.__cause__ is ssl_error async def test_closed_property_lazy_creation( loop: asyncio.AbstractEventLoop, ) -> None: """Test that closed future is created lazily.""" proto = ResponseHandler(loop=loop) # Initially, the closed future should not be created assert proto._closed is None # Accessing the property should create the future closed_future = proto.closed assert closed_future is not None assert isinstance(closed_future, asyncio.Future) assert not closed_future.done() # Subsequent access should return the same future assert proto.closed is closed_future async def test_closed_property_after_connection_lost( loop: asyncio.AbstractEventLoop, ) -> None: """Test that 
async def test_abort_without_transport(loop: asyncio.AbstractEventLoop) -> None:
    """abort() on a handler that never received a transport must not raise.

    With no transport attached, ``_exception`` is expected to be ``None``
    afterwards and ``_drop_timeout`` is expected to stay uncalled.
    """
    handler = ResponseHandler(loop=loop)

    # Replace _drop_timeout so the test can observe whether abort() reaches it.
    with mock.patch.object(handler, "_drop_timeout") as drop_timeout:
        handler.abort()

        assert handler._exception is None
        drop_timeout.assert_not_called()
@pytest.fixture
def transport(buf: bytearray) -> mock.Mock:
    """Autospecced ``asyncio.Transport`` whose writes are appended to ``buf``."""
    fake = mock.create_autospec(asyncio.Transport, spec_set=True, instance=True)

    def _append(chunk: bytes) -> None:
        buf.extend(chunk)

    def _append_all(chunks: Iterable[bytes]) -> None:
        for piece in chunks:
            _append(piece)

    fake.is_closing.return_value = False
    fake.write.side_effect = _append
    fake.writelines.side_effect = _append_all
    return fake  # type: ignore[no-any-return]
async def test_request_info(make_client_request: _RequestMaker) -> None:
    """``request_info`` must equal a RequestInfo of the URL, method and headers."""
    target = URL("http://python.org/")
    req = make_client_request("get", target)
    expected_headers = CIMultiDictProxy(req.headers)

    # Build the response by hand instead of sending the request.
    resp = req.response_class(
        "GET",
        target,
        writer=None,
        continue100=None,
        timer=TimerNoop(),
        traces=[],
        loop=req.loop,
        session=None,
        request_headers=req.headers,
        original_url=target,
    )

    expected = aiohttp.RequestInfo(target, "GET", expected_headers, target)
    assert resp.request_info == expected
async def test_host_port_nondefault_wss(make_client_request: _RequestMaker) -> None:
    """A wss:// URL with an explicit port keeps that port and reports SSL."""
    req = make_client_request("get", URL("wss://python.org:960/"))
    parsed = req.url
    assert (parsed.host, parsed.port) == ("python.org", 960)
    assert req.is_ssl()
async def test_host_header_explicit_host(make_client_request: _RequestMaker) -> None:
    """A Host header passed by the caller is used instead of the URL's host."""
    custom_headers = CIMultiDict({"host": "example.com"})
    req = make_client_request(
        "get",
        URL("http://python.org/"),
        headers=custom_headers,
    )
    host_header = req.headers["HOST"]
    assert host_header == "example.com"
@pytest.mark.parametrize(
    ("url", "headers", "expected"),
    (
        pytest.param(
            "http://localhost.",
            CIMultiDict(),
            "localhost",
            id="dot only at the end",
        ),
        pytest.param(
            "http://python.org.",
            CIMultiDict(),
            "python.org",
            id="single dot",
        ),
        pytest.param(
            "http://python.org.:99",
            CIMultiDict(),
            "python.org:99",
            id="single dot with port",
        ),
        pytest.param(
            "http://python.org...:99",
            CIMultiDict(),
            "python.org:99",
            id="multiple dots with port",
        ),
        pytest.param(
            "http://python.org.:99",
            CIMultiDict({"host": "example.com.:99"}),
            "example.com.:99",
            id="explicit host header",
        ),
        pytest.param(
            "https://python.org.",
            CIMultiDict(),
            "python.org",
            id="https",
        ),
        pytest.param(
            "https://...",
            CIMultiDict(),
            "",
            id="only dots",
        ),
        pytest.param(
            "http://príklad.example.org.:99",
            CIMultiDict(),
            "xn--prklad-4va.example.org:99",
            id="single dot with port idna",
        ),
    ),
)
async def test_host_header_fqdn(  # type: ignore[misc]
    make_client_request: _RequestMaker,
    url: str,
    headers: CIMultiDict[str],
    expected: str,
) -> None:
    """Check the Host header produced for URLs whose host ends in dots.

    Each case supplies the request URL, any caller-provided headers and the
    Host value the request is expected to carry.
    """
    req = make_client_request("get", URL(url), headers=headers)
    host_header = req.headers["HOST"]
    assert host_header == expected
async def test_headers(make_client_request: _RequestMaker) -> None:
    """An explicit Content-Type is kept and Accept-Encoding mentions gzip."""
    supplied = CIMultiDict({hdrs.CONTENT_TYPE: "text/plain"})
    req = make_client_request("post", URL("http://python.org/"), headers=supplied)

    sent = req.headers
    assert hdrs.CONTENT_TYPE in sent
    assert sent[hdrs.CONTENT_TYPE] == "text/plain"
    assert "gzip" in sent[hdrs.ACCEPT_ENCODING]
async def test_basic_auth_utf8(make_client_request: _RequestMaker) -> None:
    """BasicAuth with a non-ASCII password and an explicit utf-8 encoding."""
    credentials = aiohttp.BasicAuth("nkim", "секрет", "utf-8")
    req = make_client_request("get", URL("http://python.org"), auth=credentials)

    assert "AUTHORIZATION" in req.headers
    assert req.headers["AUTHORIZATION"] == "Basic bmtpbTrRgdC10LrRgNC10YI="
async def test_basic_auth_from_url_overridden(
    make_client_request: _RequestMaker,
) -> None:
    """An explicit ``auth`` argument is used instead of the URL's user info."""
    explicit = aiohttp.BasicAuth("nkim", "1234")
    req = make_client_request("get", URL("http://garbage@python.org"), auth=explicit)

    assert "AUTHORIZATION" in req.headers
    assert req.headers["AUTHORIZATION"] == "Basic bmtpbToxMjM0"
    assert req.url.host == "python.org"
async def test_query_multivalued_param(make_client_request: _RequestMaker) -> None:
    """A key repeated in ``params`` appears twice in the URL for every method."""
    repeated_key = (("test", "foo"), ("test", "baz"))
    rendered = "http://python.org/?test=foo&test=baz"

    for http_method in ALL_METHODS:
        req = make_client_request(
            http_method,
            URL("http://python.org"),
            params=repeated_key,
        )
        assert str(req.url) == rendered
async def test_params_empty_path_and_url(make_client_request: _RequestMaker) -> None:
    """Empty ``params`` and omitted ``params`` both leave the URL unchanged."""
    base = URL("http://python.org")
    with_empty_params = make_client_request("get", base, params={})
    without_params = make_client_request("get", base)

    for req in (with_empty_params, without_params):
        assert str(req.url) == "http://python.org"
async def test_connection_header(
    loop: asyncio.AbstractEventLoop,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check the Connection header written for each version/force_close pair.

    The same request object is re-sent once per case with the connector's
    ``force_close`` attribute patched to the value under test.
    """
    req = make_client_request("get", URL("http://python.org"), loop=loop)

    # (HTTP version, connector.force_close, expected Connection header value;
    # None means the header must be absent).
    cases = (
        (HttpVersion11, False, None),
        (HttpVersion10, False, "keep-alive"),
        (HttpVersion11, True, "close"),
        (HttpVersion10, True, None),
    )
    for version, force_close, expected in cases:
        req.version = version
        # Start every case from empty headers so a value written by the
        # previous send cannot leak into this one.
        req.headers.clear()
        with mock.patch.object(conn._connector, "force_close", force_close):
            await req._send(conn)
        # The case is attached to the assertion so a failure names its inputs.
        assert req.headers.get("CONNECTION") == expected, (version, force_close)
async def test_content_type_auto_header_bytes(
    loop: asyncio.AbstractEventLoop,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """A bytes body sent without a Content-Type gets application/octet-stream."""
    req = make_client_request(
        "post",
        URL("http://python.org"),
        data=b"hey you",
        loop=loop,
    )
    resp = await req._send(conn)

    content_type = req.headers.get("CONTENT-TYPE")
    assert content_type == "application/octet-stream"
    resp.close()
async def test_content_type_skip_auto_header_form(
    loop: asyncio.AbstractEventLoop,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Form data gets no Content-Type when that header is in skip_auto_headers."""
    form_fields = {"hey": "you"}
    req = make_client_request(
        "post",
        URL("http://python.org"),
        data=form_fields,
        loop=loop,
        skip_auto_headers={"Content-Type"},
    )
    resp = await req._send(conn)

    has_content_type = "CONTENT-TYPE" in req.headers
    assert not has_content_type
    resp.close()
async def test_pass_falsy_data_file(
    loop: asyncio.AbstractEventLoop,
    tmp_path: pathlib.Path,
    make_client_request: _RequestMaker,
) -> None:
    """A file object passed as ``data`` must produce a Content-Length header.

    Content-Type is listed in ``skip_auto_headers`` for this request.
    """
    skip = frozenset([hdrs.CONTENT_TYPE])
    # The context manager guarantees the temp file is closed even when the
    # assertion below fails; an explicit trailing close() would be skipped.
    with (tmp_path / "tmpfile").open("w+b") as testfile:
        testfile.write(b"data")
        # Rewind so the request sees the file from its beginning.
        testfile.seek(0)
        req = make_client_request(
            "post",
            URL("http://python.org/"),
            data=testfile,
            skip_auto_headers=skip,
            loop=loop,
        )
        try:
            assert req.headers.get("CONTENT-LENGTH") is not None
        finally:
            # Close the request before the file it was given is closed.
            await req._close()
async def test_content_encoding_dont_set_headers_if_no_body(
    loop: asyncio.AbstractEventLoop,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """``compress`` without a body must not add encoding headers."""
    req = make_client_request(
        "post",
        URL("http://python.org/"),
        compress="deflate",
        loop=loop,
    )
    resp = await req._send(conn)

    for header_name in ("TRANSFER-ENCODING", "CONTENT-ENCODING"):
        assert header_name not in req.headers

    await req._close()
    resp.close()
async def test_chunked_empty_body(
    loop: asyncio.AbstractEventLoop,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """A chunked request with an empty body must still invoke ``_write_bytes``."""
    req = make_client_request(
        "post",
        URL("http://python.org/"),
        chunked=True,
        loop=loop,
        data=b"",
    )

    # Swap out the writer coroutine so the call itself can be observed.
    with mock.patch.object(req, "_write_bytes") as patched_write:
        resp = await req._send(conn)

    transfer_encoding = req.headers["TRANSFER-ENCODING"]
    assert transfer_encoding == "chunked"
    assert patched_write.called
    await req._close()
    resp.close()
URL("http://python.org/"), headers=CIMultiDict({"CONTENT-LENGTH": "1000"}), chunked=True, loop=loop, )


async def test_chunked_transfer_encoding(
    loop: asyncio.AbstractEventLoop,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that chunked=True combined with a Transfer-Encoding header raises ValueError."""
    with pytest.raises(ValueError):
        make_client_request(
            "post",
            URL("http://python.org/"),
            headers=CIMultiDict({"TRANSFER-ENCODING": "chunked"}),
            chunked=True,
            loop=loop,
        )


async def test_file_upload_not_chunked(
    loop: asyncio.AbstractEventLoop,
    make_client_request: _RequestMaker,
) -> None:
    """Check that a file body is not chunked and Content-Length equals the file size."""
    file_path = pathlib.Path(__file__).parent / "aiohttp.png"
    with file_path.open("rb") as f:
        req = make_client_request("post", URL("http://python.org/"), data=f, loop=loop)
        assert not req.chunked
        assert req.headers["CONTENT-LENGTH"] == str(file_path.stat().st_size)
        await req._close()


@pytest.mark.usefixtures("parametrize_zlib_backend")
async def test_precompressed_data_stays_intact(  # type: ignore[misc]
    loop: asyncio.AbstractEventLoop,
    make_client_request: _RequestMaker,
) -> None:
    """Check that compress=False with a Content-Encoding header keeps compression and chunking off."""
    data = ZLibBackend.compress(b"foobar")
    req = make_client_request(
        "post",
        URL("http://python.org/"),
        data=data,
        headers=CIMultiDict({"CONTENT-ENCODING": "deflate"}),
        compress=False,
        loop=loop,
    )
    assert not req.compress
    assert not req.chunked
    assert req.headers["CONTENT-ENCODING"] == "deflate"
    await req._close()


async def test_body_with_size_sets_content_length( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker, ) -> None: """Test that when body has a size and no Content-Length header is set, it gets added.""" # Create a BytesPayload which has a size property data = b"test data" # Create request with data that will create a BytesPayload req = make_client_request( "post", URL("http://python.org/"), data=data, loop=loop, ) # Verify Content-Length was set from body.size assert req.headers["CONTENT-LENGTH"] == str(len(data)) assert req.body is not None assert req._body is not None # When _body is set, body returns it assert req._body.size == len(data)
await req._close()


async def test_body_payload_with_size_no_content_length(
    loop: asyncio.AbstractEventLoop,
    make_client_request: _RequestMaker,
) -> None:
    """Test that when a body payload is set via update_body, Content-Length is added."""
    # Create a payload with a known size
    data = b"payload data"
    bytes_payload = payload.BytesPayload(data)

    # Create request with no data initially
    req = make_client_request(
        "post",
        URL("http://python.org/"),
        loop=loop,
    )

    # POST method with None body should have Content-Length: 0
    assert req.headers[hdrs.CONTENT_LENGTH] == "0"

    # Update body using the public method
    await req.update_body(bytes_payload)

    # Verify Content-Length was set from body.size
    assert req.headers[hdrs.CONTENT_LENGTH] == str(len(data))
    assert req.body is bytes_payload
    assert req._body is bytes_payload  # Access _body which is the Payload
    assert req._body.size == len(data)

    # Set body back to None
    await req.update_body(None)

    # Verify Content-Length is back to 0 for POST with None body
    assert req.headers[hdrs.CONTENT_LENGTH] == "0"
    await req._close()


async def test_file_upload_not_chunked_seek(
    loop: asyncio.AbstractEventLoop,
    make_client_request: _RequestMaker,
) -> None:
    """Check that Content-Length excludes the bytes before the file's seek position."""
    file_path = pathlib.Path(__file__).parent / "aiohttp.png"
    with file_path.open("rb") as f:
        # Skip the first 100 bytes; only the remainder should be counted.
        f.seek(100)
        req = make_client_request("post", URL("http://python.org/"), data=f, loop=loop)
        assert req.headers["CONTENT-LENGTH"] == str(file_path.stat().st_size - 100)
        await req._close()


async def test_file_upload_force_chunked(
    loop: asyncio.AbstractEventLoop,
    make_client_request: _RequestMaker,
) -> None:
    """Check that chunked=True on a file body suppresses the Content-Length header."""
    file_path = pathlib.Path(__file__).parent / "aiohttp.png"
    with file_path.open("rb") as f:
        req = make_client_request(
            "post", URL("http://python.org/"), data=f, chunked=True, loop=loop
        )
        assert req.chunked
        assert "CONTENT-LENGTH" not in req.headers
        await req._close()


async def test_expect100( loop: asyncio.AbstractEventLoop, conn: mock.Mock, make_client_request: _RequestMaker, ) -> None: req =
make_client_request( "get", URL("http://python.org/"), expect100=True, loop=loop ) resp = await req._send(conn) assert "100-continue" == req.headers["EXPECT"] assert req._continue is not None req._terminate() resp.close()


async def test_expect_100_continue_header(
    loop: asyncio.AbstractEventLoop,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that an "expect: 100-continue" header leaves req._continue set after sending."""
    req = make_client_request(
        "get",
        URL("http://python.org/"),
        headers=CIMultiDict({"expect": "100-continue"}),
        loop=loop,
    )
    resp = await req._send(conn)
    assert "100-continue" == req.headers["EXPECT"]
    assert req._continue is not None
    req._terminate()
    resp.close()


async def test_data_stream(
    loop: asyncio.AbstractEventLoop,
    buf: bytearray,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that an async-generator body is written with chunked framing."""

    async def gen() -> AsyncIterator[bytes]:
        yield b"binary data"
        yield b" result"

    req = make_client_request("POST", URL("http://python.org/"), data=gen(), loop=loop)
    assert req.chunked
    assert req.headers["TRANSFER-ENCODING"] == "chunked"
    original_write_bytes = req._write_bytes

    async def _mock_write_bytes(
        writer: AbstractStreamWriter, conn: mock.Mock, content_length: int | None
    ) -> None:
        # Ensure the task is scheduled
        await asyncio.sleep(0)
        await original_write_bytes(writer, conn, content_length)

    with mock.patch.object(req, "_write_bytes", _mock_write_bytes):
        resp = await req._send(conn)
    assert asyncio.isfuture(req._writer)
    await resp.wait_for_close()
    assert req._writer is None
    # Everything after the header terminator must be the two chunks plus the
    # zero-length terminating chunk.
    assert (  # type: ignore[unreachable]
        buf.split(b"\r\n\r\n", 1)[1] == b"b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n"
    )
    await req._close()


async def test_data_file( loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock, make_client_request: _RequestMaker, ) -> None: with io.BufferedReader(io.BytesIO(b"*" * 2)) as file_handle: req = make_client_request( "POST", URL("http://python.org/"), data=file_handle, loop=loop, ) assert req.chunked assert isinstance(req.body, payload.BufferedReaderPayload) assert req.headers["TRANSFER-ENCODING"]
== "chunked" original_write_bytes = req._write_bytes async def _mock_write_bytes( writer: AbstractStreamWriter, conn: mock.Mock, content_length: int | None ) -> None: # Ensure the task is scheduled so _writer isn't None await asyncio.sleep(0) await original_write_bytes(writer, conn, content_length) with mock.patch.object(req, "_write_bytes", _mock_write_bytes): resp = await req._send(conn) assert asyncio.isfuture(req._writer) await resp.wait_for_close() assert req._writer is None assert buf.split(b"\r\n\r\n", 1)[1] == b"2\r\n" + b"*" * 2 + b"\r\n0\r\n\r\n" # type: ignore[unreachable] await req._close()


async def test_data_stream_exc(
    loop: asyncio.AbstractEventLoop,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that an exception raised inside the body generator is set on the protocol."""
    fut = loop.create_future()

    async def gen() -> AsyncIterator[bytes]:
        yield b"binary data"
        # Blocks until throw_exc() fails the future, then re-raises from here.
        await fut

    req = make_client_request("POST", URL("http://python.org/"), data=gen(), loop=loop)
    assert req.chunked
    assert req.headers["TRANSFER-ENCODING"] == "chunked"

    async def throw_exc() -> None:
        await asyncio.sleep(0.01)
        fut.set_exception(ValueError)

    t = loop.create_task(throw_exc())

    async with await req._send(conn):
        assert req._writer is not None
        await req._writer
        await t
        # assert conn.close.called
        assert conn.protocol is not None
        assert conn.protocol.set_exception.called
    await req._close()


async def test_data_stream_exc_chain( loop: asyncio.AbstractEventLoop, conn: mock.Mock, make_client_request: _RequestMaker, ) -> None: fut = loop.create_future() async def gen() -> AsyncIterator[None]: await fut assert False yield # type: ignore[unreachable] # pragma: no cover req = make_client_request("POST", URL("http://python.org/"), data=gen(), loop=loop) inner_exc = ValueError() async def throw_exc() -> None: await asyncio.sleep(0.01) fut.set_exception(inner_exc) t = loop.create_task(throw_exc()) async with await req._send(conn): assert req._writer is not None await req._writer await t # assert conn.close.called assert conn.protocol.set_exception.called
outer_exc = conn.protocol.set_exception.call_args[0][0] assert isinstance(outer_exc, ClientConnectionError) assert outer_exc.__cause__ is inner_exc await req._close()


async def test_data_stream_continue(
    loop: asyncio.AbstractEventLoop,
    buf: bytearray,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check the chunked body written for a streamed expect100 request after req._continue resolves."""

    async def gen() -> AsyncIterator[bytes]:
        yield b"binary data"
        yield b" result"

    req = make_client_request(
        "POST", URL("http://python.org/"), data=gen(), expect100=True, loop=loop
    )
    assert req.chunked

    async def coro() -> None:
        # Resolve the continue future shortly after the request is sent.
        await asyncio.sleep(0.0001)
        assert req._continue is not None
        req._continue.set_result(1)

    t = loop.create_task(coro())

    resp = await req._send(conn)
    assert req._writer is not None
    await req._writer
    await t
    assert (
        buf.split(b"\r\n\r\n", 1)[1] == b"b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n"
    )
    await req._close()
    resp.close()


async def test_data_continue(
    loop: asyncio.AbstractEventLoop,
    buf: bytearray,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check the body written for a bytes expect100 request after req._continue resolves."""
    req = make_client_request(
        "POST", URL("http://python.org/"), data=b"data", expect100=True, loop=loop
    )

    async def coro() -> None:
        # Resolve the continue future shortly after the request is sent.
        await asyncio.sleep(0.0001)
        assert req._continue is not None
        req._continue.set_result(1)

    t = loop.create_task(coro())

    resp = await req._send(conn)
    assert req._writer is not None
    await req._writer
    await t
    assert buf.split(b"\r\n\r\n", 1)[1] == b"data"
    await req._close()
    resp.close()


async def test_close(
    loop: asyncio.AbstractEventLoop,
    buf: bytearray,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that the generator body has been written by the time req._close() returns."""

    async def gen() -> AsyncIterator[bytes]:
        await asyncio.sleep(0.00001)
        yield b"result"

    req = make_client_request("POST", URL("http://python.org/"), data=gen(), loop=loop)
    resp = await req._send(conn)
    await req._close()
    assert buf.split(b"\r\n\r\n", 1)[1] == b"6\r\nresult\r\n0\r\n\r\n"
    # _close() is awaited a second time after the body check.
    await req._close()
    resp.close()


async def test_bad_version( loop: asyncio.AbstractEventLoop, conn: mock.Mock, make_client_request: _RequestMaker, ) ->
None: req = make_client_request( "GET", URL("http://python.org"), loop=loop, headers=CIMultiDict({"Connection": "Close"}), version=("1", "1\r\nInjected-Header: not allowed"), # type: ignore[arg-type] ) with pytest.raises(AttributeError): await req._send(conn)


async def test_custom_response_class(
    loop: asyncio.AbstractEventLoop,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that _send builds its response from the supplied response_class."""

    class CustomResponse(ClientResponse):
        async def read(self) -> bytes:
            return b"customized!"

    req = make_client_request(
        "GET", URL("http://python.org/"), response_class=CustomResponse, loop=loop
    )
    resp = await req._send(conn)
    assert await resp.read() == b"customized!"
    await req._close()
    resp.close()


async def test_oserror_on_write_bytes(
    loop: asyncio.AbstractEventLoop,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that an OSError from the writer is set on the protocol as ClientOSError."""
    req = make_client_request("POST", URL("http://python.org/"), loop=loop)
    await req.update_body(b"test data")

    writer = WriterMock()
    writer.write.side_effect = OSError

    await req._write_bytes(writer, conn, None)

    assert conn.protocol.set_exception.called
    exc = conn.protocol.set_exception.call_args[0][0]
    assert isinstance(exc, aiohttp.ClientOSError)


@pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()") async def test_cancel_close( # type: ignore[misc] loop: asyncio.AbstractEventLoop, conn: mock.Mock, make_client_request: _RequestMaker, ) -> None: req = make_client_request("get", URL("http://python.org"), loop=loop) req._writer = asyncio.Future() # type: ignore[assignment] t = asyncio.create_task(req._close()) # Start waiting on _writer await asyncio.sleep(0) t.cancel() # Cancellation should not be suppressed.
with pytest.raises(asyncio.CancelledError): await t


async def test_terminate(
    loop: asyncio.AbstractEventLoop,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that _terminate cancels a writer that reports it is not done."""
    req = make_client_request("get", URL("http://python.org"), loop=loop)

    async def _mock_write_bytes(*args: object, **kwargs: object) -> None:
        # Ensure the task is scheduled
        await asyncio.sleep(0)

    with mock.patch.object(req, "_write_bytes", _mock_write_bytes):
        resp = await req._send(conn)

    assert req._writer is not None
    assert resp._writer is not None
    await resp._writer
    # Swap in a mock writer whose done() is False so cancel() can be observed.
    writer = WriterMock()
    writer.done = mock.Mock(return_value=False)
    writer.cancel = mock.Mock()
    req._writer = writer
    resp._writer = writer

    assert req._writer is not None
    assert resp._writer is not None
    req._terminate()
    writer.cancel.assert_called_with()
    writer.done.assert_called_with()
    resp.close()


def test_terminate_with_closed_loop( loop: asyncio.AbstractEventLoop, conn: mock.Mock, ) -> None: req = resp = writer = None async def go() -> None: nonlocal req, resp, writer # Can't use make_client_request here, due to closing the loop mid-test.
req = ClientRequest( "get", URL("http://python.org"), loop=loop, params={}, headers=CIMultiDict[str](), skip_auto_headers=None, data=None, cookies=BaseCookie[str](), auth=None, version=HttpVersion11, compress=False, chunked=None, expect100=False, response_class=ClientResponse, proxy=None, proxy_auth=None, timer=TimerNoop(), session=None, # type: ignore[arg-type] ssl=True, proxy_headers=None, traces=[], trust_env=False, server_hostname=None, ) async def _mock_write_bytes(*args: object, **kwargs: object) -> None: # Ensure the task is scheduled await asyncio.sleep(0) with mock.patch.object(req, "_write_bytes", _mock_write_bytes): resp = await req._send(conn) assert req._writer is not None writer = WriterMock() writer.done = mock.Mock(return_value=False) req._writer = writer resp._writer = writer await asyncio.sleep(0.05) loop.run_until_complete(go()) loop.close() assert req is not None req._terminate() assert req._writer is None assert writer is not None assert not writer.cancel.called assert resp is not None resp.close()


async def test_terminate_without_writer(make_client_request: _RequestMaker) -> None:
    """Check that _terminate leaves _writer as None when no writer was ever set."""
    req = make_client_request("get", URL("http://python.org"))
    assert req._writer is None

    req._terminate()
    assert req._writer is None


async def test_custom_req_rep( loop: asyncio.AbstractEventLoop, create_mocked_conn: mock.Mock ) -> None: conn = None class CustomResponse(ClientResponse): async def start(self, connection: Connection) -> ClientResponse: nonlocal conn conn = connection self.status = 123 self.reason = "Test OK" self._headers = CIMultiDictProxy(CIMultiDict()) self.cookies = SimpleCookie() return self called = False class CustomRequest(ClientRequest): async def _send(self, conn: Connection) -> ClientResponse: resp = self.response_class( self.method, self.url, writer=self._writer, continue100=self._continue, timer=self._timer, traces=self._traces, loop=self.loop, session=self._session, request_headers=self.headers, original_url=self.original_url, )
self.response = resp nonlocal called called = True return resp async def create_connection( req: ClientRequest, traces: object, timeout: object ) -> Connection: assert isinstance(req, CustomRequest) return create_mocked_conn() # type: ignore[no-any-return] connector = BaseConnector() with mock.patch.object(connector, "_create_connection", create_connection): session = aiohttp.ClientSession( request_class=CustomRequest, response_class=CustomResponse, connector=connector, ) resp = await session.request("get", URL("http://example.com/path/to")) assert isinstance(resp, CustomResponse) assert called resp.close() await session.close() assert conn is not None conn.close()


def test_bad_fingerprint(loop: asyncio.AbstractEventLoop) -> None:
    """Check that Fingerprint rejects the bytes b"invalid" with ValueError."""
    with pytest.raises(ValueError):
        Fingerprint(b"invalid")


def test_insecure_fingerprint_md5(loop: asyncio.AbstractEventLoop) -> None:
    """Check that Fingerprint rejects an MD5 digest with ValueError."""
    with pytest.raises(ValueError):
        Fingerprint(hashlib.md5(b"foo").digest())


def test_insecure_fingerprint_sha1(loop: asyncio.AbstractEventLoop) -> None:
    """Check that Fingerprint rejects a SHA-1 digest with ValueError."""
    with pytest.raises(ValueError):
        Fingerprint(hashlib.sha1(b"foo").digest())


@pytest.mark.parametrize(
    "has_brotli,has_zstd,expected",
    [
        (False, False, "gzip, deflate"),
        (True, False, "gzip, deflate, br"),
        (False, True, "gzip, deflate, zstd"),
        (True, True, "gzip, deflate, br, zstd"),
    ],
)
def test_gen_default_accept_encoding(
    has_brotli: bool, has_zstd: bool, expected: str
) -> None:
    """Check _gen_default_accept_encoding() for each HAS_BROTLI/HAS_ZSTD combination."""
    with mock.patch("aiohttp.client_reqrep.HAS_BROTLI", has_brotli):
        with mock.patch("aiohttp.client_reqrep.HAS_ZSTD", has_zstd):
            assert _gen_default_accept_encoding() == expected


@pytest.mark.parametrize( "netrc_contents", ("machine example.com login username password pass\n",), indirect=("netrc_contents",), ) @pytest.mark.usefixtures("netrc_contents") async def test_basicauth_from_netrc_present_untrusted_env( # type: ignore[misc] make_client_request: _RequestMaker, ) -> None: """Test no authorization header is sent via netrc if trust_env is False""" req = make_client_request("get",
URL("http://example.com"), trust_env=False) assert hdrs.AUTHORIZATION not in req.headers


@pytest.mark.parametrize(
    "netrc_contents",
    ("",),
    indirect=("netrc_contents",),
)
@pytest.mark.usefixtures("netrc_contents")
async def test_basicauth_from_empty_netrc(  # type: ignore[misc]
    make_client_request: _RequestMaker,
) -> None:
    """Test that no Authorization header is sent when netrc is empty"""
    req = make_client_request("get", URL("http://example.com"), trust_env=True)
    assert hdrs.AUTHORIZATION not in req.headers


async def test_connection_key_with_proxy(
    make_client_request: _RequestMaker,
) -> None:
    """Verify the proxy headers are included in the ConnectionKey when a proxy is used."""
    proxy = URL("http://proxy.example.com")
    req = make_client_request(
        "GET",
        URL("http://example.com"),
        proxy=proxy,
        proxy_headers=CIMultiDict({"X-Proxy": "true"}),
        loop=asyncio.get_running_loop(),
    )
    assert req.connection_key.proxy_headers_hash is not None
    await req._close()


async def test_connection_key_without_proxy(
    make_client_request: _RequestMaker,
) -> None:
    """Verify the proxy headers are not included in the ConnectionKey when no proxy is used."""
    # If proxy is unspecified, proxy_headers should be ignored
    req = make_client_request(
        "GET",
        URL("http://example.com"),
        proxy_headers=CIMultiDict({"X-Proxy": "true"}),
        loop=asyncio.get_running_loop(),
    )
    assert req.connection_key.proxy_headers_hash is None
    await req._close()


def test_request_info_back_compat() -> None: """Test RequestInfo can be created without real_url.""" url = URL("http://example.com") other_url = URL("http://example.org") assert ( aiohttp.RequestInfo( url=url, method="GET", headers=CIMultiDictProxy(CIMultiDict()) ).real_url is url ) assert ( aiohttp.RequestInfo(url, "GET", CIMultiDictProxy(CIMultiDict())).real_url is url ) assert ( aiohttp.RequestInfo( url, "GET", CIMultiDictProxy(CIMultiDict()), real_url=url ).real_url is url ) assert ( aiohttp.RequestInfo( url, "GET", CIMultiDictProxy(CIMultiDict()), real_url=other_url
).real_url is other_url )


def test_request_info_tuple_new() -> None:
    """Test RequestInfo must be created with real_url using tuple.__new__."""
    url = URL("http://example.com")
    # A three-item tuple carries no real_url entry, so reading it must fail.
    with pytest.raises(IndexError):
        tuple.__new__(
            aiohttp.RequestInfo, (url, "GET", CIMultiDictProxy(CIMultiDict()))
        ).real_url

    # With a fourth item supplied, real_url is that item.
    assert (
        tuple.__new__(
            aiohttp.RequestInfo, (url, "GET", CIMultiDictProxy(CIMultiDict()), url)
        ).real_url
        is url
    )


async def test_get_content_length(make_client_request: _RequestMaker) -> None:
    """Test _get_content_length method extracts Content-Length correctly."""
    req = make_client_request("get", URL("http://python.org/"))

    # No Content-Length header
    assert req._get_content_length() is None

    # Valid Content-Length header
    req.headers["Content-Length"] = "42"
    assert req._get_content_length() == 42

    # Invalid Content-Length header
    req.headers["Content-Length"] = "invalid"
    with pytest.raises(ValueError, match="Invalid Content-Length header: invalid"):
        req._get_content_length()


async def test_write_bytes_with_content_length_limit(
    loop: asyncio.AbstractEventLoop,
    buf: bytearray,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Test that write_bytes respects content_length limit for different body types."""
    # Test with bytes data
    data = b"Hello World"
    req = make_client_request("post", URL("http://python.org/"), loop=loop)

    await req.update_body(data)

    writer = StreamWriter(protocol=conn.protocol, loop=loop)
    # Use content_length=5 to truncate data
    await req._write_bytes(writer, conn, 5)

    # Verify only the first 5 bytes were written
    assert buf == b"Hello"
    await req._close()


@pytest.mark.parametrize( "data", [ [b"Part1", b"Part2", b"Part3"], b"Part1Part2Part3", ], ) async def test_write_bytes_with_iterable_content_length_limit( # type: ignore[misc] loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock, data: list[bytes] | bytes, make_client_request: _RequestMaker, ) -> None: """Test that write_bytes respects content_length limit for iterable data.""" #
Test with iterable data req = make_client_request("post", URL("http://python.org/"), loop=loop) # Convert list to async generator if needed if isinstance(data, list): async def gen() -> AsyncIterator[bytes]: for chunk in data: yield chunk await req.update_body(gen()) else: await req.update_body(data) writer = StreamWriter(protocol=conn.protocol, loop=loop) # Use content_length=7 to truncate at the middle of Part2 await req._write_bytes(writer, conn, 7) assert len(buf) == 7 assert buf == b"Part1Pa" await req._close()


async def test_write_bytes_empty_iterable_with_content_length(
    loop: asyncio.AbstractEventLoop,
    buf: bytearray,
    conn: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Test that write_bytes handles empty iterable body with content_length."""
    req = make_client_request("post", URL("http://python.org/"), loop=loop)

    # Create an empty async generator
    async def gen() -> AsyncIterator[bytes]:
        return
        yield  # pragma: no cover  # This makes it a generator but never executes

    await req.update_body(gen())

    writer = StreamWriter(protocol=conn.protocol, loop=loop)
    # Use content_length=10 with empty body
    await req._write_bytes(writer, conn, 10)

    # Verify nothing was written
    assert len(buf) == 0
    await req._close()


async def test_update_body_closes_previous_payload(
    make_client_request: _RequestMaker,
) -> None:
    """Test that update_body properly closes the previous payload."""
    req = make_client_request("POST", URL("http://python.org/"))

    # Create a mock payload that tracks if it was closed
    mock_payload = mock.create_autospec(payload.Payload, spec_set=True, instance=True)

    # Set initial payload (assigned to the private attribute directly, so
    # update_body below is the only call that can close it)
    req._body = mock_payload

    # Update body with new data
    await req.update_body(b"new body data")

    # Verify the previous payload was closed
    mock_payload.close.assert_called_once()

    # Verify new body is set (it's a BytesPayload now)
    assert isinstance(req.body, payload.BytesPayload)
    await req._close()


async def test_update_body_with_different_types( make_client_request: _RequestMaker,
) -> None: """Test update_body with various data types.""" req = make_client_request("POST", URL("http://python.org/")) # Test with bytes await req.update_body(b"bytes data") assert isinstance(req.body, payload.BytesPayload) # Test with string await req.update_body("string data") assert isinstance(req.body, payload.BytesPayload) # Test with None (clears body) await req.update_body(None) assert req.body._value == b"" await req._close()


async def test_update_body_with_chunked_encoding(
    make_client_request: _RequestMaker,
) -> None:
    """Test that update_body properly handles chunked transfer encoding."""
    # Create request with chunked=True
    req = make_client_request("POST", URL("http://python.org/"), chunked=True)

    # Verify Transfer-Encoding header is set
    assert req.headers["Transfer-Encoding"] == "chunked"
    assert "Content-Length" not in req.headers

    # Update body - should maintain chunked encoding
    await req.update_body(b"chunked data")
    assert req.headers["Transfer-Encoding"] == "chunked"
    assert "Content-Length" not in req.headers
    assert isinstance(req.body, payload.BytesPayload)

    # Update with different body - chunked should remain
    await req.update_body(b"different chunked data")
    assert req.headers["Transfer-Encoding"] == "chunked"
    assert "Content-Length" not in req.headers

    # Clear body - chunked header should remain
    await req.update_body(None)
    assert req.headers["Transfer-Encoding"] == "chunked"
    assert "Content-Length" not in req.headers
    await req._close()


async def test_update_body_get_method_with_none_body( make_client_request: _RequestMaker, ) -> None: """Test that update_body with GET method and None body doesn't call update_transfer_encoding.""" # Create GET request req = make_client_request("GET", URL("http://python.org/")) # GET requests shouldn't have Transfer-Encoding or Content-Length initially assert "Transfer-Encoding" not in req.headers assert "Content-Length" not in req.headers # Update body to None - should not trigger update_transfer_encoding # This covers
the branch where body is None AND method is in GET_METHODS await req.update_body(None) # Headers should remain unchanged assert "Transfer-Encoding" not in req.headers assert "Content-Length" not in req.headers await req._close() async def test_update_body_updates_content_length( make_client_request: _RequestMaker, ) -> None: """Test that update_body properly updates Content-Length header when body size changes.""" req = make_client_request("POST", URL("http://python.org/")) # Set initial body with known size await req.update_body(b"initial data") initial_content_length = req.headers.get("Content-Length") assert initial_content_length == "12" # len(b"initial data") = 12 # Update body with different size await req.update_body(b"much longer data than before") new_content_length = req.headers.get("Content-Length") assert new_content_length == "28" # len(b"much longer data than before") = 28 # Update body with shorter data await req.update_body(b"short") assert req.headers.get("Content-Length") == "5" # len(b"short") = 5 # Clear body await req.update_body(None) # For None body with POST method, Content-Length should be set to 0 assert req.headers[hdrs.CONTENT_LENGTH] == "0" await req._close() async def test_expect100_with_body_becomes_empty( make_client_request: _RequestMaker, ) -> None: """Test that write_bytes handles body becoming empty after expect100 handling.""" # Create a mock writer and connection mock_writer = mock.create_autospec(StreamWriter, instance=True, spec_set=True) mock_conn = mock.Mock() # Create a request req = make_client_request( "POST", URL("http://test.example.com/"), loop=asyncio.get_event_loop() ) req._body = mock.Mock() # Start with a body # Now set body to empty payload to simulate a race condition # where req._body is set to None after expect100 handling req._body = payload.PAYLOAD_REGISTRY.get(b"", disposition=None) await req._write_bytes(mock_writer, mock_conn, None) @pytest.mark.parametrize( ("method", "data", "expected_content_length"), 
[ # GET methods should not have Content-Length with None body ("GET", None, None), ("HEAD", None, None), ("OPTIONS", None, None), ("TRACE", None, None), # POST methods should have Content-Length: 0 with None body ("POST", None, "0"), ("PUT", None, "0"), ("PATCH", None, "0"), ("DELETE", None, "0"), # Empty bytes should always set Content-Length: 0 ("GET", b"", "0"), ("HEAD", b"", "0"), ("POST", b"", "0"), ("PUT", b"", "0"), # Non-empty bytes should set appropriate Content-Length ("GET", b"test", "4"), ("POST", b"test", "4"), ("PUT", b"hello world", "11"), ("PATCH", b"data", "4"), ("DELETE", b"x", "1"), ], ) async def test_content_length_for_methods( # type: ignore[misc] method: str, data: bytes | None, expected_content_length: str | None, loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker, ) -> None: """Test that Content-Length header is set correctly for all HTTP methods.""" req = make_client_request(method, URL("http://python.org/"), data=data, loop=loop) actual_content_length = req.headers.get(hdrs.CONTENT_LENGTH) assert actual_content_length == expected_content_length


@pytest.mark.parametrize("method", ["GET", "HEAD", "OPTIONS", "TRACE"])
def test_get_methods_classification(method: str) -> None:
    """Test that GET-like methods are correctly classified."""
    # Membership in ClientRequest.GET_METHODS is the classification under test.
    assert method in ClientRequest.GET_METHODS


@pytest.mark.parametrize("method", ["POST", "PUT", "PATCH", "DELETE"])
def test_non_get_methods_classification(method: str) -> None:
    """Test that POST-like methods are not in GET_METHODS."""
    assert method not in ClientRequest.GET_METHODS


async def test_content_length_with_string_data( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker, ) -> None: """Test Content-Length when data is a string.""" data = "Hello, World!"
req = make_client_request("POST", URL("http://python.org/"), data=data, loop=loop) # String should be encoded to bytes, default encoding is utf-8 assert req.headers[hdrs.CONTENT_LENGTH] == str(len(data.encode("utf-8"))) await req._close() async def test_content_length_with_async_iterable( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker, ) -> None: """Test that async iterables use chunked encoding, not Content-Length.""" async def data_gen() -> AsyncIterator[bytes]: yield b"chunk1" # pragma: no cover req = make_client_request( "POST", URL("http://python.org/"), data=data_gen(), loop=loop ) assert hdrs.CONTENT_LENGTH not in req.headers assert req.chunked assert req.headers[hdrs.TRANSFER_ENCODING] == "chunked" await req._close() async def test_content_length_not_overridden( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker, ) -> None: """Test that explicitly set Content-Length is not overridden.""" req = make_client_request( "POST", URL("http://python.org/"), data=b"test", headers=CIMultiDict({hdrs.CONTENT_LENGTH: "100"}), loop=loop, ) # Should keep the explicitly set value assert req.headers[hdrs.CONTENT_LENGTH] == "100" await req._close() async def test_content_length_with_formdata( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker, ) -> None: """Test Content-Length with FormData.""" form = aiohttp.FormData() form.add_field("field", "value") req = make_client_request("POST", URL("http://python.org/"), data=form, loop=loop) # FormData with known size should set Content-Length assert hdrs.CONTENT_LENGTH in req.headers await req._close() async def test_no_content_length_with_chunked( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker, ) -> None: """Test that chunked encoding prevents Content-Length header.""" req = make_client_request( "POST", URL("http://python.org/"), data=b"test", chunked=True, loop=loop, ) assert hdrs.CONTENT_LENGTH not in req.headers assert req.headers[hdrs.TRANSFER_ENCODING] 
== "chunked" await req._close()


@pytest.mark.parametrize("method", ["POST", "PUT", "PATCH", "DELETE"])
async def test_update_body_none_sets_content_length_zero(  # type: ignore[misc]
    method: str,
    loop: asyncio.AbstractEventLoop,
    make_client_request: _RequestMaker,
) -> None:
    """Test that updating body to None sets Content-Length: 0 for POST-like methods."""
    # Create request with initial body
    req = make_client_request(
        method, URL("http://python.org/"), data=b"initial", loop=loop
    )
    assert req.headers[hdrs.CONTENT_LENGTH] == "7"

    # Update body to None
    await req.update_body(None)
    assert req.headers[hdrs.CONTENT_LENGTH] == "0"
    await req._close()


@pytest.mark.parametrize("method", ["GET", "HEAD", "OPTIONS", "TRACE"])
async def test_update_body_none_no_content_length_for_get_methods(  # type: ignore[misc]
    method: str,
    loop: asyncio.AbstractEventLoop,
    make_client_request: _RequestMaker,
) -> None:
    """Test that updating body to None doesn't set Content-Length for GET-like methods."""
    # Create request with initial body
    req = make_client_request(
        method, URL("http://python.org/"), data=b"initial", loop=loop
    )
    assert req.headers[hdrs.CONTENT_LENGTH] == "7"

    # Update body to None
    await req.update_body(None)
    assert hdrs.CONTENT_LENGTH not in req.headers
    await req._close()


async def test_multiple_requests_share_empty_body_safely(
    make_client_request: _RequestMaker,
) -> None:
    """Test that multiple ClientRequest objects safely share the empty body payload."""
    requests: list[ClientRequest] = []
    for i in range(5):
        req = make_client_request("GET", URL(f"http://example.com/path{i}"))
        requests.append(req)

    # Every body-less request must expose the same class-level payload object.
    empty_body = ClientRequest._EMPTY_BODY
    for i, req in enumerate(requests):
        assert req.body is empty_body, f"Request {i} has different empty body"
        assert req.body.size == 0
        assert req.body.consumed is False
    assert empty_body.consumed is False
    assert empty_body.size == 0


async def test_empty_body_isolation_after_update( make_client_request: _RequestMaker, ) -> None: """Test that updating
one request's body doesn't affect other requests.""" req1 = make_client_request("POST", URL("http://example.com/1")) req2 = make_client_request("POST", URL("http://example.com/2")) assert req1.body is ClientRequest._EMPTY_BODY assert req2.body is ClientRequest._EMPTY_BODY await req1.update_body(b"new data") assert req1.body is not ClientRequest._EMPTY_BODY assert req1.body.size == 8 assert req2.body is ClientRequest._EMPTY_BODY assert req2.body.size == 0 assert req2.body.consumed is False assert ClientRequest._EMPTY_BODY.consumed is False assert ClientRequest._EMPTY_BODY.size == 0 ================================================ FILE: tests/test_client_response.py ================================================ # Tests for aiohttp/client.py import asyncio import gc import sys from http.cookies import SimpleCookie from json import JSONDecodeError from unittest import mock import pytest from multidict import CIMultiDict, CIMultiDictProxy from pytest_mock import MockerFixture from yarl import URL import aiohttp from aiohttp import ClientSession, hdrs, http from aiohttp.client_reqrep import ClientResponse from aiohttp.connector import Connection from aiohttp.helpers import TimerNoop from aiohttp.multipart import BadContentDispositionHeader from aiohttp.tracing import Trace class WriterMock(mock.AsyncMock): def done(self) -> bool: return True @pytest.fixture def session() -> mock.Mock: return mock.Mock() async def test_http_processing_error(session: ClientSession) -> None: loop = mock.Mock() url = URL("http://del-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) loop.get_debug = mock.Mock() loop.get_debug.return_value = True connection = mock.Mock() connection.protocol = aiohttp.DataQueue(loop) connection.protocol.set_exception(http.HttpProcessingError()) with pytest.raises(aiohttp.ClientResponseError) as info: 
await response.start(connection) assert info.value.request_info.url is url response.close() def test_del(session: ClientSession) -> None: loop = mock.Mock() url = URL("http://del-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) loop.get_debug = mock.Mock() loop.get_debug.return_value = True connection = mock.Mock() response._closed = False response._connection = connection loop.set_exception_handler(lambda loop, ctx: None) with pytest.warns(ResourceWarning): del response gc.collect() connection.release.assert_called_with() def test_close(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) response._closed = False response._connection = mock.Mock() response.close() assert response.connection is None response.close() response.close() def test_wait_for_100_1( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://python.org") response = ClientResponse( "get", url, continue100=loop.create_future(), writer=WriterMock(), timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) assert response._continue is not None response.close() def test_wait_for_100_2( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://python.org") response = ClientResponse( "get", url, continue100=None, writer=WriterMock(), timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) assert response._continue is None response.close() def test_repr(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: url = 
URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) response.status = 200 response.reason = "Ok" assert "" in repr(response) def test_repr_non_ascii_url() -> None: url = URL("http://fake-host.org/\u03bb") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) assert "" in repr(response) def test_repr_non_ascii_reason() -> None: url = URL("http://fake-host.org/path") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) response.reason = "\u03bb" assert "" in repr( response ) async def test_read_and_release_connection( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": fut = loop.create_future() fut.set_result(b"payload") return fut content = response.content = mock.Mock() content.read.side_effect = side_effect res = await response.read() assert res == b"payload" assert response._connection is None async def test_read_and_release_connection_with_error( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) content = response.content = mock.Mock() 
content.read.return_value = loop.create_future() content.read.return_value.set_exception(ValueError) with pytest.raises(ValueError): await response.read() assert response._closed async def test_release(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) fut = loop.create_future() fut.set_result(b"") content = response.content = mock.Mock() content.readany.return_value = fut response.release() assert response._connection is None @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Other implementations has different GC strategies", ) async def test_release_on_del( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: connection = mock.Mock() connection.protocol.upgraded = False def run(conn: Connection) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) response._closed = False response._connection = conn run(connection) assert connection.release.called async def test_response_eof( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) response._closed = False conn = response._connection = mock.Mock() conn.protocol.upgraded = False response._response_eof() assert conn.release.called assert response._connection is None async def test_response_eof_upgraded( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", 
url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) conn = response._connection = mock.Mock() conn.protocol.upgraded = True response._response_eof() assert not conn.release.called assert response._connection is conn async def test_response_eof_after_connection_detach( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) response._closed = False conn = response._connection = mock.Mock() conn.protocol = None response._response_eof() assert conn.release.called assert response._connection is None async def test_text(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": fut = loop.create_future() fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut h = {"Content-Type": "application/json;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = response.content = mock.Mock() content.read.side_effect = side_effect res = await response.text() assert res == '{"тест": "пройден"}' assert response._connection is None async def test_text_bad_encoding( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) def side_effect(*args: object, **kwargs: object) -> 
"asyncio.Future[bytes]": fut = loop.create_future() fut.set_result('{"тестkey": "пройденvalue"}'.encode("cp1251")) return fut # lie about the encoding h = {"Content-Type": "application/json;charset=utf-8"} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = response.content = mock.Mock() content.read.side_effect = side_effect with pytest.raises(UnicodeDecodeError): await response.text() # only the valid utf-8 characters will be returned res = await response.text(errors="ignore") assert res == '{"key": "value"}' assert response._connection is None async def test_text_badly_encoded_encoding_header( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: session._resolve_charset = lambda *_: "utf-8" url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": fut = loop.create_future() fut.set_result(b"foo") return fut h = {"Content-Type": "text/html; charset=\udc81gutf-8\udc81\udc8d"} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = response.content = mock.Mock() content.read.side_effect = side_effect await response.read() encoding = response.get_encoding() assert encoding == "utf-8" async def test_text_custom_encoding( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": fut = loop.create_future() fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut h = {"Content-Type": "application/json"} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = 
response.content = mock.Mock() content.read.side_effect = side_effect with mock.patch.object(response, "get_encoding") as m: res = await response.text(encoding="cp1251") assert res == '{"тест": "пройден"}' assert response._connection is None assert not m.called @pytest.mark.parametrize("content_type", ("text/plain", "text/plain;charset=invalid")) async def test_text_charset_resolver( content_type: str, loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: session._resolve_charset = lambda r, b: "cp1251" url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": fut = loop.create_future() fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut h = {"Content-Type": content_type} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = response.content = mock.Mock() content.read.side_effect = side_effect await response.read() res = await response.text() assert res == '{"тест": "пройден"}' assert response._connection is None assert response.get_encoding() == "cp1251" async def test_get_encoding_body_none( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) h = {"Content-Type": "text/html"} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = response.content = mock.Mock() content.read.side_effect = AssertionError with pytest.raises( RuntimeError, match="^Cannot compute fallback encoding of a not yet read body$", ): response.get_encoding() assert response.closed async def test_text_after_read( loop: asyncio.AbstractEventLoop, session: ClientSession ) 
-> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": fut = loop.create_future() fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut h = {"Content-Type": "application/json;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = response.content = mock.Mock() content.read.side_effect = side_effect res = await response.text() assert res == '{"тест": "пройден"}' assert response._connection is None async def test_json(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": fut = loop.create_future() fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut h = {"Content-Type": "application/json;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = response.content = mock.Mock() content.read.side_effect = side_effect res = await response.json() assert res == {"тест": "пройден"} assert response._connection is None async def test_json_extended_content_type( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": fut = loop.create_future() fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut h = {"Content-Type": 
"application/this.is-1_content+subtype+json;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = response.content = mock.Mock() content.read.side_effect = side_effect res = await response.json() assert res == {"тест": "пройден"} assert response._connection is None async def test_json_custom_content_type( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": fut = loop.create_future() fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut h = {"Content-Type": "custom/type;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = response.content = mock.Mock() content.read.side_effect = side_effect res = await response.json(content_type="custom/type") assert res == {"тест": "пройден"} assert response._connection is None async def test_json_custom_loader( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) h = {"Content-Type": "application/json;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) response._body = b"data" def custom(content: str) -> str: return content + "-custom" res = await response.json(loads=custom) assert res == "data-custom" async def test_json_invalid_content_type( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, 
request_headers=CIMultiDict[str](), original_url=url, ) h = {"Content-Type": "data/octet-stream"} response._headers = CIMultiDictProxy(CIMultiDict(h)) response._body = b"" response.status = 500 with pytest.raises(aiohttp.ContentTypeError) as info: await response.json() assert info.value.request_info == response.request_info assert info.value.status == 500 async def test_json_no_content( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) h = {"Content-Type": "application/json"} response._headers = CIMultiDictProxy(CIMultiDict(h)) response._body = b"" with pytest.raises(JSONDecodeError): await response.json(content_type=None) async def test_json_override_encoding( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": fut = loop.create_future() fut.set_result('{"тест": "пройден"}'.encode("cp1251")) return fut h = {"Content-Type": "application/json;charset=utf8"} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = response.content = mock.Mock() content.read.side_effect = side_effect with mock.patch.object(response, "get_encoding") as m: res = await response.json(encoding="cp1251") assert res == {"тест": "пройден"} assert response._connection is None assert not m.called def test_get_encoding_unknown( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], 
loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) h = {"Content-Type": "application/json"} response._headers = CIMultiDictProxy(CIMultiDict(h)) assert response.get_encoding() == "utf-8" def test_raise_for_status_2xx() -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) response.status = 200 response.reason = "OK" response.raise_for_status() # should not raise def test_raise_for_status_4xx() -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) response.status = 409 response.reason = "CONFLICT" with pytest.raises(aiohttp.ClientResponseError) as cm: response.raise_for_status() assert str(cm.value.status) == "409" assert str(cm.value.message) == "CONFLICT" assert response.closed def test_raise_for_status_4xx_without_reason() -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) response.status = 404 response.reason = "" with pytest.raises(aiohttp.ClientResponseError) as cm: response.raise_for_status() assert str(cm.value.status) == "404" assert str(cm.value.message) == "" assert response.closed def test_resp_host() -> None: url = URL("http://del-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) assert "del-cl-resp.org" == response.host def test_content_type() -> None: url = 
URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) h = {"Content-Type": "application/json;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) assert "application/json" == response.content_type def test_content_type_no_header() -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) response._headers = CIMultiDictProxy(CIMultiDict({})) assert "application/octet-stream" == response.content_type def test_charset() -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) h = {"Content-Type": "application/json;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) assert "cp1251" == response.charset def test_charset_no_header() -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) response._headers = CIMultiDictProxy(CIMultiDict({})) assert response.charset is None def test_charset_no_charset() -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) h = {"Content-Type": "application/json"} response._headers = CIMultiDictProxy(CIMultiDict(h)) assert response.charset is None def 
test_content_disposition_full() -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) h = {"Content-Disposition": 'attachment; filename="archive.tar.gz"; foo=bar'} response._headers = CIMultiDictProxy(CIMultiDict(h)) assert response.content_disposition is not None assert "attachment" == response.content_disposition.type assert "bar" == response.content_disposition.parameters["foo"] assert "archive.tar.gz" == response.content_disposition.filename with pytest.raises(TypeError): response.content_disposition.parameters["foo"] = "baz" # type: ignore[index] def test_content_disposition_no_parameters() -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) h = {"Content-Disposition": "attachment"} response._headers = CIMultiDictProxy(CIMultiDict(h)) assert response.content_disposition is not None assert "attachment" == response.content_disposition.type assert response.content_disposition.filename is None assert {} == response.content_disposition.parameters @pytest.mark.parametrize( "content_disposition", ( 'attachment; filename="archive.tar.gz";', 'attachment;; filename="archive.tar.gz"', ), ) def test_content_disposition_empty_parts(content_disposition: str) -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) h = {"Content-Disposition": content_disposition} response._headers = CIMultiDictProxy(CIMultiDict(h)) with pytest.warns(BadContentDispositionHeader): assert response.content_disposition is not 
None assert "attachment" == response.content_disposition.type assert "archive.tar.gz" == response.content_disposition.filename def test_content_disposition_no_header() -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) response._headers = CIMultiDictProxy(CIMultiDict({})) assert response.content_disposition is None def test_default_encoding_is_utf8() -> None: url = URL("http://def-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=None, request_headers=CIMultiDict[str](), original_url=url, ) response._headers = CIMultiDictProxy(CIMultiDict({})) response._body = b"" assert response.get_encoding() == "utf-8" def test_response_request_info() -> None: url = URL("http://def-cl-resp.org") h = {"Content-Type": "application/json;charset=cp1251"} headers = CIMultiDict(h) response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=headers, original_url=url, ) assert url == response.request_info.url assert "get" == response.request_info.method assert headers == response.request_info.headers def test_request_info_in_exception() -> None: url = URL("http://def-cl-resp.org") h = {"Content-Type": "application/json;charset=cp1251"} headers = CIMultiDict(h) response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=headers, original_url=url, ) response.status = 409 response.reason = "CONFLICT" with pytest.raises(aiohttp.ClientResponseError) as cm: response.raise_for_status() assert cm.value.request_info == response.request_info def test_no_redirect_history_in_exception() -> None: url = 
URL("http://def-cl-resp.org") h = {"Content-Type": "application/json;charset=cp1251"} headers = CIMultiDict(h) response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=headers, original_url=url, ) response.status = 409 response.reason = "CONFLICT" with pytest.raises(aiohttp.ClientResponseError) as cm: response.raise_for_status() assert () == cm.value.history def test_redirect_history_in_exception() -> None: hist_url = URL("http://def-cl-resp.org") u = "http://def-cl-resp.org/index.htm" url = URL(u) hist_headers = {"Content-Type": "application/json;charset=cp1251", "Location": u} h = {"Content-Type": "application/json;charset=cp1251"} headers = CIMultiDict(h) response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=headers, original_url=url, ) response.status = 409 response.reason = "CONFLICT" hist_response = ClientResponse( "get", hist_url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=headers, original_url=hist_url, ) hist_response._headers = CIMultiDictProxy(CIMultiDict(hist_headers)) hist_response.status = 301 hist_response.reason = "REDIRECT" response._history = (hist_response,) with pytest.raises(aiohttp.ClientResponseError) as cm: response.raise_for_status() assert (hist_response,) == cm.value.history async def test_response_read_triggers_callback( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: trace = mock.create_autospec(Trace, instance=True, spec_set=True) response_method = "get" response_url = URL("http://def-cl-resp.org") response_body = b"This is response" response = ClientResponse( response_method, response_url, writer=WriterMock(), continue100=None, timer=TimerNoop(), loop=loop, session=session, traces=[trace], request_headers=CIMultiDict[str](), 
original_url=response_url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": fut = loop.create_future() fut.set_result(response_body) return fut h = {"Content-Type": "application/json;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) content = response.content = mock.Mock() content.read.side_effect = side_effect res = await response.read() assert res == response_body assert response._connection is None assert trace.send_response_chunk_received.called assert trace.send_response_chunk_received.call_args == mock.call( response_method, response_url, response_body ) def test_response_cookies( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://python.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) cookies = response.cookies # Ensure the same cookies object is returned each time assert response.cookies is cookies def test_response_real_url( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org/#urlfragment") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) assert response.url == url.with_fragment(None) assert response.real_url == url def test_response_links_comma_separated( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org/") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) h = ( ( "Link", ( "; rel=next, " "; rel=home" ), ), ) response._headers = CIMultiDictProxy(CIMultiDict(h)) assert response.links == { "next": {"url": URL("http://example.com/page/1.html"), "rel": 
"next"}, "home": {"url": URL("http://example.com/"), "rel": "home"}, } def test_response_links_multiple_headers( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org/") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) h = ( ("Link", "; rel=next"), ("Link", "; rel=home"), ) response._headers = CIMultiDictProxy(CIMultiDict(h)) assert response.links == { "next": {"url": URL("http://example.com/page/1.html"), "rel": "next"}, "home": {"url": URL("http://example.com/"), "rel": "home"}, } def test_response_links_no_rel( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org/") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) h = (("Link", ""),) response._headers = CIMultiDictProxy(CIMultiDict(h)) assert response.links == { "http://example.com/": {"url": URL("http://example.com/")} } def test_response_links_quoted( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org/") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) h = (("Link", '; rel="home-page"'),) response._headers = CIMultiDictProxy(CIMultiDict(h)) assert response.links == { "home-page": {"url": URL("http://example.com/"), "rel": "home-page"} } def test_response_links_relative( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org/") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), 
original_url=url, ) h = (("Link", "; rel=rel"),) response._headers = CIMultiDictProxy(CIMultiDict(h)) assert response.links == { "rel": {"url": URL("http://def-cl-resp.org/relative/path"), "rel": "rel"} } def test_response_links_empty( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: url = URL("http://def-cl-resp.org/") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) response._headers = CIMultiDictProxy(CIMultiDict()) assert response.links == {} def test_response_not_closed_after_get_ok(mocker: MockerFixture) -> None: url = URL("http://del-cl-resp.org") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) response.status = 400 response.reason = "Bad Request" response._closed = False spy = mocker.spy(response, "raise_for_status") assert not response.ok assert not response.closed assert spy.call_count == 0 def test_response_duplicate_cookie_names( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: """ Test that response.cookies handles duplicate cookie names correctly. Note: This behavior (losing cookies with same name but different domains/paths) is arguably undesirable, but we promise to return a SimpleCookie object, and SimpleCookie uses cookie name as the key. This is documented behavior. 
To access all cookies including duplicates, users should use: - response.headers.getall('Set-Cookie') for raw headers - The session's cookie jar correctly stores all cookies """ url = URL("http://example.com") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) # Set headers with duplicate cookie names but different domains headers = CIMultiDict( [ ( "Set-Cookie", "session-id=123-4567890; Domain=.example.com; Path=/; Secure", ), ("Set-Cookie", "session-id=098-7654321; Domain=.www.example.com; Path=/"), ("Set-Cookie", "user-pref=dark; Domain=.example.com; Path=/"), ("Set-Cookie", "user-pref=light; Domain=api.example.com; Path=/"), ] ) response._headers = CIMultiDictProxy(headers) # Set raw cookie headers as done in ClientResponse.start() response._raw_cookie_headers = tuple(headers.getall("Set-Cookie", [])) # SimpleCookie only keeps the last cookie with each name # This is expected behavior since SimpleCookie uses name as the key assert len(response.cookies) == 2 # Only 'session-id' and 'user-pref' assert response.cookies["session-id"].value == "098-7654321" # Last one wins assert response.cookies["user-pref"].value == "light" # Last one wins def test_response_raw_cookie_headers_preserved( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: """Test that raw Set-Cookie headers are preserved in _raw_cookie_headers.""" url = URL("http://example.com") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) # Set headers with multiple cookies cookie_headers = [ "session-id=123; Domain=.example.com; Path=/; Secure", "session-id=456; Domain=.www.example.com; Path=/", "tracking=xyz; Domain=.example.com; Path=/; HttpOnly", ] headers: CIMultiDict[str] = CIMultiDict() for cookie_hdr in 
cookie_headers: headers.add("Set-Cookie", cookie_hdr) response._headers = CIMultiDictProxy(headers) # Set raw cookie headers as done in ClientResponse.start() response._raw_cookie_headers = tuple(response.headers.getall(hdrs.SET_COOKIE, [])) # Verify raw headers are preserved assert response._raw_cookie_headers == tuple(cookie_headers) assert len(response._raw_cookie_headers) == 3 # But SimpleCookie only has unique names assert len(response.cookies) == 2 # 'session-id' and 'tracking' def test_response_cookies_setter_updates_raw_headers( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: """Test that setting cookies property updates _raw_cookie_headers.""" url = URL("http://example.com") response = ClientResponse( "get", url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, request_headers=CIMultiDict[str](), original_url=url, ) # Create a SimpleCookie with some cookies cookies = SimpleCookie() cookies["session-id"] = "123456" cookies["session-id"]["domain"] = ".example.com" cookies["session-id"]["path"] = "/" cookies["session-id"]["secure"] = True cookies["tracking"] = "xyz789" cookies["tracking"]["domain"] = ".example.com" cookies["tracking"]["httponly"] = True # Set the cookies property response.cookies = cookies # Verify _raw_cookie_headers was updated assert response._raw_cookie_headers is not None assert len(response._raw_cookie_headers) == 2 assert isinstance(response._raw_cookie_headers, tuple) # Check the raw headers contain the expected cookie strings raw_headers = list(response._raw_cookie_headers) assert any("session-id=123456" in h for h in raw_headers) assert any("tracking=xyz789" in h for h in raw_headers) assert any("Secure" in h for h in raw_headers) assert any("HttpOnly" in h for h in raw_headers) # Verify cookies property returns the same object assert response.cookies is cookies # Test setting empty cookies empty_cookies = SimpleCookie() response.cookies = empty_cookies # Should not 
set _raw_cookie_headers for empty cookies assert response._raw_cookie_headers is None ================================================ FILE: tests/test_client_session.py ================================================ import asyncio import contextlib import gc import io import json import sys import warnings from collections import deque from collections.abc import Awaitable, Callable, Iterator from http.cookies import BaseCookie, SimpleCookie from types import SimpleNamespace from typing import Any, NoReturn, TypedDict, cast from unittest import mock from uuid import uuid4 import pytest from multidict import CIMultiDict, MultiDict from pytest_mock import MockerFixture from yarl import URL import aiohttp from aiohttp import abc, client, hdrs, tracing, web from aiohttp.client import ClientSession from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientRequest, ConnectionKey from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector from aiohttp.cookiejar import CookieJar from aiohttp.http import RawResponseMessage from aiohttp.payload import Payload from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.test_utils import TestServer from aiohttp.tracing import ( Trace, TraceRequestChunkSentParams, TraceRequestEndParams, TraceRequestExceptionParams, TraceRequestHeadersSentParams, TraceRequestRedirectParams, TraceRequestStartParams, TraceResponseChunkReceivedParams, ) class _Params(TypedDict): headers: dict[str, str] max_redirects: int compress: str chunked: bool expect100: bool read_until_eof: bool @pytest.fixture def connector( loop: asyncio.AbstractEventLoop, create_mocked_conn: Callable[[], ResponseHandler] ) -> Iterator[BaseConnector]: async def make_conn() -> BaseConnector: return BaseConnector() key = ConnectionKey("localhost", 80, False, True, None, None, None) conn = loop.run_until_complete(make_conn()) proto = create_mocked_conn() conn._conns[key] = deque([(proto, 123)]) yield 
conn loop.run_until_complete(conn.close()) @pytest.fixture def create_session( loop: asyncio.AbstractEventLoop, ) -> Iterator[Callable[..., Awaitable[ClientSession]]]: session = None async def maker(*args: Any, **kwargs: Any) -> ClientSession: nonlocal session session = ClientSession(*args, **kwargs) return session yield maker if session is not None: loop.run_until_complete(session.close()) @pytest.fixture def session( create_session: Callable[..., Awaitable[ClientSession]], loop: asyncio.AbstractEventLoop, ) -> ClientSession: return loop.run_until_complete(create_session()) @pytest.fixture def params() -> _Params: return dict( headers={"Authorization": "Basic ..."}, max_redirects=2, compress="deflate", chunked=True, expect100=True, read_until_eof=False, ) @pytest.fixture async def auth_server(aiohttp_server: AiohttpServer) -> TestServer: """Create a server with an auth handler that returns auth header or 'no_auth'.""" async def handler(request: web.Request) -> web.Response: auth_header = request.headers.get(hdrs.AUTHORIZATION) if auth_header: return web.Response(text=f"auth:{auth_header}") return web.Response(text="no_auth") app = web.Application() app.router.add_get("/", handler) return await aiohttp_server(app) async def test_close_coro( create_session: Callable[..., Awaitable[ClientSession]], ) -> None: session = await create_session() await session.close() async def test_init_headers_simple_dict( create_session: Callable[..., Awaitable[ClientSession]], ) -> None: session = await create_session(headers={"h1": "header1", "h2": "header2"}) assert sorted(session.headers.items()) == ([("h1", "header1"), ("h2", "header2")]) async def test_init_headers_list_of_tuples( create_session: Callable[..., Awaitable[ClientSession]], ) -> None: session = await create_session( headers=[("h1", "header1"), ("h2", "header2"), ("h3", "header3")] ) assert session.headers == CIMultiDict( [("h1", "header1"), ("h2", "header2"), ("h3", "header3")] ) async def 
test_init_headers_MultiDict( create_session: Callable[..., Awaitable[ClientSession]], ) -> None: session = await create_session( headers=MultiDict([("h1", "header1"), ("h2", "header2"), ("h3", "header3")]) ) assert session.headers == CIMultiDict( [("H1", "header1"), ("H2", "header2"), ("H3", "header3")] ) async def test_init_headers_list_of_tuples_with_duplicates( create_session: Callable[..., Awaitable[ClientSession]], ) -> None: session = await create_session( headers=[("h1", "header11"), ("h2", "header21"), ("h1", "header12")] ) assert session.headers == CIMultiDict( [("H1", "header11"), ("H2", "header21"), ("H1", "header12")] ) async def test_init_cookies_with_simple_dict( create_session: Callable[..., Awaitable[ClientSession]], ) -> None: session = await create_session(cookies={"c1": "cookie1", "c2": "cookie2"}) cookies = session.cookie_jar.filter_cookies(URL()) assert set(cookies) == {"c1", "c2"} assert cookies["c1"].value == "cookie1" assert cookies["c2"].value == "cookie2" async def test_init_cookies_with_list_of_tuples( create_session: Callable[..., Awaitable[ClientSession]], ) -> None: session = await create_session(cookies=[("c1", "cookie1"), ("c2", "cookie2")]) cookies = session.cookie_jar.filter_cookies(URL()) assert set(cookies) == {"c1", "c2"} assert cookies["c1"].value == "cookie1" assert cookies["c2"].value == "cookie2" async def test_merge_headers( create_session: Callable[..., Awaitable[ClientSession]], ) -> None: # Check incoming simple dict session = await create_session(headers={"h1": "header1", "h2": "header2"}) headers = session._prepare_headers({"h1": "h1"}) assert isinstance(headers, CIMultiDict) assert headers == {"h1": "h1", "h2": "header2"} async def test_merge_headers_with_multi_dict( create_session: Callable[..., Awaitable[ClientSession]], ) -> None: session = await create_session(headers={"h1": "header1", "h2": "header2"}) headers = session._prepare_headers(MultiDict([("h1", "h1")])) assert isinstance(headers, CIMultiDict) assert 
headers == {"h1": "h1", "h2": "header2"} async def test_merge_headers_with_list_of_tuples( create_session: Callable[..., Awaitable[ClientSession]], ) -> None: session = await create_session(headers={"h1": "header1", "h2": "header2"}) headers = session._prepare_headers([("h1", "h1")]) assert isinstance(headers, CIMultiDict) assert headers == {"h1": "h1", "h2": "header2"} async def test_merge_headers_with_list_of_tuples_duplicated_names( create_session: Callable[..., Awaitable[ClientSession]], ) -> None: session = await create_session(headers={"h1": "header1", "h2": "header2"}) headers = session._prepare_headers([("h1", "v1"), ("h1", "v2")]) assert isinstance(headers, CIMultiDict) assert list(sorted(headers.items())) == [ ("h1", "v1"), ("h1", "v2"), ("h2", "header2"), ] @pytest.mark.parametrize("obj", (object(), None)) async def test_invalid_data(session: ClientSession, obj: object) -> None: with pytest.raises(TypeError, match="expected str"): await session.post("http://example.test/", data={"some": obj}) async def test_http_GET(session: ClientSession, params: _Params) -> None: with mock.patch( "aiohttp.client.ClientSession._request", autospec=True, spec_set=True ) as patched: await session.get("http://test.example.com", params={"x": 1}, **params) assert patched.called, "`ClientSession._request` not called" assert list(patched.call_args) == [ (session, "GET", "http://test.example.com"), dict(params={"x": 1}, allow_redirects=True, **params), ] async def test_http_OPTIONS(session: ClientSession, params: _Params) -> None: with mock.patch( "aiohttp.client.ClientSession._request", autospec=True, spec_set=True ) as patched: await session.options("http://opt.example.com", params={"x": 2}, **params) assert patched.called, "`ClientSession._request` not called" assert list(patched.call_args) == [ (session, "OPTIONS", "http://opt.example.com"), dict(params={"x": 2}, allow_redirects=True, **params), ] async def test_http_HEAD(session: ClientSession, params: _Params) -> None: 
with mock.patch( "aiohttp.client.ClientSession._request", autospec=True, spec_set=True ) as patched: await session.head("http://head.example.com", params={"x": 2}, **params) assert patched.called, "`ClientSession._request` not called" assert list(patched.call_args) == [ (session, "HEAD", "http://head.example.com"), dict(params={"x": 2}, allow_redirects=False, **params), ] async def test_http_POST(session: ClientSession, params: _Params) -> None: with mock.patch( "aiohttp.client.ClientSession._request", autospec=True, spec_set=True ) as patched: await session.post( "http://post.example.com", params={"x": 2}, data="Some_data", **params ) assert patched.called, "`ClientSession._request` not called" assert list(patched.call_args) == [ (session, "POST", "http://post.example.com"), dict(params={"x": 2}, data="Some_data", **params), ] async def test_http_PUT(session: ClientSession, params: _Params) -> None: with mock.patch( "aiohttp.client.ClientSession._request", autospec=True, spec_set=True ) as patched: await session.put( "http://put.example.com", params={"x": 2}, data="Some_data", **params ) assert patched.called, "`ClientSession._request` not called" assert list(patched.call_args) == [ (session, "PUT", "http://put.example.com"), dict(params={"x": 2}, data="Some_data", **params), ] async def test_http_PATCH(session: ClientSession, params: _Params) -> None: with mock.patch( "aiohttp.client.ClientSession._request", autospec=True, spec_set=True ) as patched: await session.patch( "http://patch.example.com", params={"x": 2}, data="Some_data", **params ) assert patched.called, "`ClientSession._request` not called" assert list(patched.call_args) == [ (session, "PATCH", "http://patch.example.com"), dict(params={"x": 2}, data="Some_data", **params), ] async def test_http_DELETE(session: ClientSession, params: _Params) -> None: with mock.patch( "aiohttp.client.ClientSession._request", autospec=True, spec_set=True ) as patched: await session.delete("http://delete.example.com", 
params={"x": 2}, **params) assert patched.called, "`ClientSession._request` not called" assert list(patched.call_args) == [ (session, "DELETE", "http://delete.example.com"), dict(params={"x": 2}, **params), ] async def test_close( create_session: Callable[..., Awaitable[ClientSession]], connector: BaseConnector ) -> None: session = await create_session(connector=connector) await session.close() assert session.connector is None assert connector.closed async def test_closed(session: ClientSession) -> None: assert not session.closed await session.close() assert session.closed async def test_connector( create_session: Callable[..., Awaitable[ClientSession]], loop: asyncio.AbstractEventLoop, mocker: MockerFixture, ) -> None: connector = TCPConnector() m = mocker.spy(connector, "close") session = await create_session(connector=connector) assert session.connector is connector await session.close() assert m.called await connector.close() async def test_create_connector( create_session: Callable[..., Awaitable[ClientSession]], loop: asyncio.AbstractEventLoop, mocker: MockerFixture, ) -> None: session = await create_session() m = mocker.spy(session.connector, "close") await session.close() assert m.called @pytest.mark.skipif( sys.version_info < (3, 11), reason="Use test_ssl_shutdown_timeout_passed_to_connector_pre_311 for Python < 3.11", ) async def test_ssl_shutdown_timeout_passed_to_connector() -> None: # Test default value (no warning expected) async with ClientSession() as session: assert isinstance(session.connector, TCPConnector) assert session.connector._ssl_shutdown_timeout == 0 # Test custom value - expect deprecation warning with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): async with ClientSession(ssl_shutdown_timeout=1.0) as session: assert isinstance(session.connector, TCPConnector) assert session.connector._ssl_shutdown_timeout == 1.0 # Test None value - expect deprecation warning with pytest.warns( 
DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): async with ClientSession(ssl_shutdown_timeout=None) as session: assert isinstance(session.connector, TCPConnector) assert session.connector._ssl_shutdown_timeout is None # Test that it doesn't affect when custom connector is provided with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): custom_conn = TCPConnector(ssl_shutdown_timeout=2.0) with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): async with ClientSession( connector=custom_conn, ssl_shutdown_timeout=1.0 ) as session: assert session.connector is not None assert isinstance(session.connector, TCPConnector) assert ( session.connector._ssl_shutdown_timeout == 2.0 ) # Should use connector's value @pytest.mark.skipif( sys.version_info >= (3, 11), reason="This test is for Python < 3.11 runtime warning behavior", ) async def test_ssl_shutdown_timeout_passed_to_connector_pre_311() -> None: """Test that both deprecation and runtime warnings are issued on Python < 3.11.""" # Test custom value - expect both deprecation and runtime warnings with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") async with ClientSession(ssl_shutdown_timeout=1.0) as session: assert isinstance(session.connector, TCPConnector) assert session.connector._ssl_shutdown_timeout == 1.0 # Should have deprecation warnings (from ClientSession and TCPConnector) and runtime warning # ClientSession emits 1 DeprecationWarning, TCPConnector emits 1 DeprecationWarning + 1 RuntimeWarning = 3 total assert len(w) == 3 deprecation_count = sum( 1 for warn in w if issubclass(warn.category, DeprecationWarning) ) runtime_count = sum( 1 for warn in w if issubclass(warn.category, RuntimeWarning) ) assert deprecation_count == 2 # One from ClientSession, one from TCPConnector assert runtime_count == 1 # One from TCPConnector # Test with custom connector with 
warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") custom_conn = TCPConnector(ssl_shutdown_timeout=2.0) # Should have both deprecation and runtime warnings assert len(w) == 2 with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): async with ClientSession( connector=custom_conn, ssl_shutdown_timeout=1.0 ) as session: assert session.connector is not None assert isinstance(session.connector, TCPConnector) assert ( session.connector._ssl_shutdown_timeout == 2.0 ) # Should use connector's value def test_connector_loop(loop: asyncio.AbstractEventLoop) -> None: with contextlib.ExitStack() as stack: another_loop = asyncio.new_event_loop() stack.enter_context(contextlib.closing(another_loop)) async def make_connector() -> TCPConnector: return TCPConnector() connector = another_loop.run_until_complete(make_connector()) with pytest.raises(RuntimeError) as ctx: async def make_sess() -> ClientSession: return ClientSession(connector=connector) loop.run_until_complete(make_sess()) expected = "Session and connector have to use same event loop" assert str(ctx.value).startswith(expected) another_loop.run_until_complete(connector.close()) def test_detach(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: conn = session.connector assert conn is not None try: assert not conn.closed session.detach() assert session.connector is None assert session.closed assert not conn.closed finally: loop.run_until_complete(conn.close()) async def test_request_closed_session(session: ClientSession) -> None: await session.close() with pytest.raises(RuntimeError): await session.request("get", "/") async def test_close_flag_for_closed_connector(session: ClientSession) -> None: conn = session.connector assert conn is not None assert not session.closed await conn.close() assert session.closed async def test_double_close( connector: BaseConnector, create_session: Callable[..., Awaitable[ClientSession]] ) -> None: session = await 
create_session(connector=connector) await session.close() assert session.connector is None await session.close() assert session.closed assert connector.closed async def test_del(connector: BaseConnector, loop: asyncio.AbstractEventLoop) -> None: loop.set_debug(False) # N.B. don't use session fixture, it stores extra reference internally session = ClientSession(connector=connector) logs = [] loop.set_exception_handler(lambda loop, ctx: logs.append(ctx)) with pytest.warns(ResourceWarning): del session gc.collect() assert len(logs) == 1 expected = {"client_session": mock.ANY, "message": "Unclosed client session"} assert logs[0] == expected async def test_del_debug( connector: BaseConnector, loop: asyncio.AbstractEventLoop ) -> None: loop.set_debug(True) # N.B. don't use session fixture, it stores extra reference internally session = ClientSession(connector=connector) logs = [] loop.set_exception_handler(lambda loop, ctx: logs.append(ctx)) with pytest.warns(ResourceWarning): del session gc.collect() assert len(logs) == 1 expected = { "client_session": mock.ANY, "message": "Unclosed client session", "source_traceback": mock.ANY, } assert logs[0] == expected async def test_borrow_connector_loop( connector: BaseConnector, create_session: Callable[..., Awaitable[ClientSession]], loop: asyncio.AbstractEventLoop, ) -> None: async with ClientSession(connector=connector) as session: assert session._loop is loop async def test_reraise_os_error( create_session: Callable[..., Awaitable[ClientSession]], create_mocked_conn: Callable[[], ResponseHandler], ) -> None: err = OSError(1, "permission error") req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req_factory = mock.Mock(return_value=req) req._send.side_effect = err req._body = mock.create_autospec(Payload, spec_set=True, instance=True) session = await create_session(request_class=req_factory) async def create_connection( req: object, traces: object, timeout: object ) -> ResponseHandler: # return self.transport, 
self.protocol return create_mocked_conn() with mock.patch.object(session._connector, "_create_connection", create_connection): with mock.patch.object( session._connector, "_release", autospec=True, spec_set=True ): with pytest.raises(aiohttp.ClientOSError) as ctx: await session.request("get", "http://example.com") e = ctx.value assert e.errno == err.errno assert e.strerror == err.strerror async def test_close_conn_on_error( create_session: Callable[..., Awaitable[ClientSession]], create_mocked_conn: Callable[[], ResponseHandler], ) -> None: class UnexpectedException(BaseException): pass err = UnexpectedException("permission error") req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req_factory = mock.Mock(return_value=req) req._send.side_effect = err req._body = mock.create_autospec(Payload, spec_set=True, instance=True) session = await create_session(request_class=req_factory) connections = [] assert session._connector is not None original_connect = session._connector.connect async def connect( req: ClientRequest, traces: list[Trace], timeout: aiohttp.ClientTimeout ) -> Connection: conn = await original_connect(req, traces, timeout) connections.append(conn) return conn async def create_connection( req: object, traces: object, timeout: object ) -> ResponseHandler: # return self.transport, self.protocol conn = create_mocked_conn() return conn with mock.patch.object(session._connector, "connect", connect): with mock.patch.object( session._connector, "_create_connection", create_connection ): with mock.patch.object( session._connector, "_release", autospec=True, spec_set=True ): with pytest.raises(UnexpectedException): async with session.request("get", "http://example.com"): pass # normally called during garbage collection. 
triggers an exception # if the connection wasn't already closed for c in connections: c.__del__() @pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss"]) async def test_ws_connect_allowed_protocols( # type: ignore[misc] create_session: Callable[..., Awaitable[ClientSession]], create_mocked_conn: Callable[[], ResponseHandler], protocol: str, ws_key: str, key_data: bytes, ) -> None: resp = mock.create_autospec(aiohttp.ClientResponse, spec_set=True, instance=True) resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } resp.url = URL(f"{protocol}://example") resp.cookies = SimpleCookie() req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req._body = None # No body for WebSocket upgrade requests req_factory = mock.Mock(return_value=req) req._send = mock.AsyncMock(return_value=resp) # BaseConnector allows all high level protocols by default connector = BaseConnector() session = await create_session(connector=connector, request_class=req_factory) connections = [] assert session._connector is not None original_connect = session._connector.connect async def connect( req: ClientRequest, traces: list[Trace], timeout: aiohttp.ClientTimeout ) -> Connection: conn = await original_connect(req, traces, timeout) connections.append(conn) return conn async def create_connection( req: object, traces: object, timeout: object ) -> ResponseHandler: return create_mocked_conn() connector = session._connector with ( mock.patch.object(connector, "connect", connect), mock.patch.object(connector, "_create_connection", create_connection), mock.patch.object(connector, "_release"), mock.patch("aiohttp.client.os") as m_os, ): m_os.urandom.return_value = key_data await session.ws_connect(f"{protocol}://example") # normally called during garbage collection. 
triggers an exception # if the connection wasn't already closed for c in connections: c.close() c.__del__() await session.close() @pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss", "unix"]) async def test_ws_connect_unix_socket_allowed_protocols( # type: ignore[misc] create_session: Callable[..., Awaitable[ClientSession]], create_mocked_conn: Callable[[], ResponseHandler], protocol: str, ws_key: str, key_data: bytes, ) -> None: resp = mock.create_autospec(aiohttp.ClientResponse, spec_set=True, instance=True) resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } resp.url = URL(f"{protocol}://example") resp.cookies = SimpleCookie() req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req._body = None # No body for WebSocket upgrade requests req_factory = mock.Mock(return_value=req) req._send = mock.AsyncMock(return_value=resp) # UnixConnector allows all high level protocols by default and unix sockets session = await create_session( connector=UnixConnector(path=""), request_class=req_factory ) connections = [] assert session._connector is not None original_connect = session._connector.connect async def connect( req: ClientRequest, traces: list[Trace], timeout: aiohttp.ClientTimeout ) -> Connection: conn = await original_connect(req, traces, timeout) connections.append(conn) return conn async def create_connection( req: object, traces: object, timeout: object ) -> ResponseHandler: return create_mocked_conn() connector = session._connector with ( mock.patch.object(connector, "connect", connect), mock.patch.object(connector, "_create_connection", create_connection), mock.patch.object(connector, "_release"), mock.patch("aiohttp.client.os") as m_os, ): m_os.urandom.return_value = key_data await session.ws_connect(f"{protocol}://example") # normally called during garbage collection. 
triggers an exception # if the connection wasn't already closed for c in connections: c.close() c.__del__() await session.close() async def test_cookie_jar_usage( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: req_url = None class MockCookieJar(abc.AbstractCookieJar): def __init__(self) -> None: self._update_cookies_mock = mock.Mock() self._filter_cookies_mock = mock.Mock(return_value=BaseCookie()) self._clear_mock = mock.Mock() self._clear_domain_mock = mock.Mock() self._items: list[Any] = [] @property def quote_cookie(self) -> bool: return True def clear(self, predicate: abc.ClearCookiePredicate | None = None) -> None: self._clear_mock(predicate) def clear_domain(self, domain: str) -> None: self._clear_domain_mock(domain) def update_cookies(self, cookies: Any, response_url: URL = URL()) -> None: self._update_cookies_mock(cookies, response_url) def filter_cookies(self, request_url: URL) -> BaseCookie[str]: return cast(BaseCookie[str], self._filter_cookies_mock(request_url)) def __len__(self) -> int: return len(self._items) def __iter__(self) -> Iterator[Any]: return iter(self._items) jar = MockCookieJar() assert jar.quote_cookie is True assert len(jar) == 0 assert list(jar) == [] jar.clear() jar.clear_domain("example.com") async def handler(request: web.Request) -> web.Response: nonlocal req_url req_url = "http://%s/" % request.host resp = web.Response() resp.set_cookie("response", "resp_value") return resp app = web.Application() app.router.add_route("GET", "/", handler) session = await aiohttp_client( app, cookies={"request": "req_value"}, cookie_jar=jar ) # Updating the cookie jar with initial user defined cookies jar._update_cookies_mock.assert_called_with({"request": "req_value"}, URL()) jar._update_cookies_mock.reset_mock() resp = await session.get("/") resp.release() assert req_url is not None # Filtering the cookie jar before sending the request, # getting the request URL as only parameter 
jar._filter_cookies_mock.assert_called_with(URL(req_url)) # Updating the cookie jar with the response cookies assert jar._update_cookies_mock.called resp_cookies = jar._update_cookies_mock.call_args[0][0] # Now update_cookies is called with a list of tuples assert isinstance(resp_cookies, list) assert len(resp_cookies) == 1 assert resp_cookies[0][0] == "response" assert resp_cookies[0][1].value == "resp_value" async def test_cookies_with_not_quoted_cookie_jar( aiohttp_server: AiohttpServer, ) -> None: async def handler(_: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) jar = CookieJar(quote_cookie=False) cookies = {"name": "val=foobar"} async with aiohttp.ClientSession(cookie_jar=jar) as sess: resp = await sess.request("GET", server.make_url("/"), cookies=cookies) assert resp.request_info.headers.get("Cookie", "") == "name=val=foobar" async def test_session_default_version(loop: asyncio.AbstractEventLoop) -> None: session = aiohttp.ClientSession() assert session.version == aiohttp.HttpVersion11 await session.close() async def test_proxy_str(session: ClientSession, params: _Params) -> None: with mock.patch( "aiohttp.client.ClientSession._request", autospec=True, spec_set=True ) as patched: await session.get("http://test.example.com", proxy="http://proxy.com", **params) assert patched.called, "`ClientSession._request` not called" assert list(patched.call_args) == [ (session, "GET", "http://test.example.com"), dict(allow_redirects=True, proxy="http://proxy.com", **params), ] async def test_default_proxy(loop: asyncio.AbstractEventLoop) -> None: proxy_url = URL("http://proxy.example.com") proxy_auth = mock.Mock() proxy_url2 = URL("http://proxy.example2.com") proxy_auth2 = mock.Mock() class OnCall(Exception): pass request_class_mock = mock.Mock(side_effect=OnCall()) session = ClientSession( proxy=proxy_url, proxy_auth=proxy_auth, request_class=request_class_mock ) 
async def test_request_tracing(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """POST a body once and check which trace signals fired and with what.

    Start/end signals must be sent exactly once with the expected params,
    the redirect signal must stay silent, and the chunk/header signals must
    have seen the request body, the response body and the default header.
    """

    async def handler(request: web.Request) -> web.Response:
        return web.json_response({"ok": True})

    # Signature templates; the autospec'd mocks below are created from them.
    async def on_request_start_callback(
        session: ClientSession,
        trace_config_ctx: SimpleNamespace,
        params: TraceRequestStartParams,
    ) -> None:
        """Mock signature"""

    async def on_request_end_callback(
        session: ClientSession,
        trace_config_ctx: SimpleNamespace,
        params: TraceRequestEndParams,
    ) -> None:
        """Mock signature"""

    async def on_request_redirect_callback(
        session: ClientSession,
        trace_config_ctx: SimpleNamespace,
        params: TraceRequestRedirectParams,
    ) -> None:
        """Mock signature"""

    app = web.Application()
    app.router.add_post("/", handler)

    trace_config_ctx = mock.Mock()
    body = "This is request body"
    sent_headers: CIMultiDict[str] = CIMultiDict()

    # Mocks that enforce the callback signatures declared above.
    on_request_start = mock.create_autospec(on_request_start_callback, spec_set=True)
    on_request_end = mock.create_autospec(on_request_end_callback, spec_set=True)
    on_request_redirect = mock.create_autospec(
        on_request_redirect_callback, spec_set=True
    )

    with io.BytesIO() as sent_body, io.BytesIO() as received_body:

        async def on_request_chunk_sent(
            session: object,
            context: object,
            params: tracing.TraceRequestChunkSentParams,
        ) -> None:
            sent_body.write(params.chunk)

        async def on_response_chunk_received(
            session: object,
            context: object,
            params: tracing.TraceResponseChunkReceivedParams,
        ) -> None:
            received_body.write(params.chunk)

        async def on_request_headers_sent(
            session: object,
            context: object,
            params: tracing.TraceRequestHeadersSentParams,
        ) -> None:
            sent_headers.extend(params.headers)

        ctx_factory = mock.Mock(return_value=trace_config_ctx)
        trace_config = aiohttp.TraceConfig(trace_config_ctx_factory=ctx_factory)

        # Autospec'd mocks first, then the data-gathering callbacks.
        trace_config.on_request_start.append(on_request_start)
        trace_config.on_request_end.append(on_request_end)
        trace_config.on_request_redirect.append(on_request_redirect)
        trace_config.on_request_chunk_sent.append(on_request_chunk_sent)
        trace_config.on_response_chunk_received.append(on_response_chunk_received)
        trace_config.on_request_headers_sent.append(on_request_headers_sent)

        headers = CIMultiDict({"Custom-Header": "Custom value"})
        session = await aiohttp_client(
            app, trace_configs=[trace_config], headers=headers
        )

        async with session.post("/", data=body, trace_request_ctx={}) as resp:
            await resp.json()

            expected_start = aiohttp.TraceRequestStartParams(
                hdrs.METH_POST, session.make_url("/"), headers
            )
            on_request_start.assert_called_once_with(
                session.session, trace_config_ctx, expected_start
            )

            expected_end = aiohttp.TraceRequestEndParams(
                hdrs.METH_POST, session.make_url("/"), headers, resp
            )
            on_request_end.assert_called_once_with(
                session.session, trace_config_ctx, expected_end
            )

            assert not on_request_redirect.called
            assert sent_body.getvalue() == body.encode("utf8")
            expected_response = json.dumps({"ok": True}).encode("utf8")
            assert received_body.getvalue() == expected_response
            assert sent_headers["Custom-Header"] == "Custom value"
spec_set=True ) on_request_end = mock.create_autospec(on_request_end_callback, spec_set=True) on_request_exception = mock.create_autospec( on_request_exception_callback, spec_set=True ) on_request_chunk_sent = mock.create_autospec( on_request_chunk_sent_callback, spec_set=True ) on_response_chunk_received = mock.create_autospec( on_response_chunk_received_callback, spec_set=True ) on_request_headers_sent = mock.create_autospec( on_request_headers_sent_callback, spec_set=True ) mocks = [ on_request_start, on_request_redirect, on_request_end, on_request_exception, on_request_chunk_sent, on_response_chunk_received, on_request_headers_sent, ] trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=mock.Mock()) ) trace_config.on_request_start.append(on_request_start) trace_config.on_request_redirect.append(on_request_redirect) trace_config.on_request_end.append(on_request_end) trace_config.on_request_exception.append(on_request_exception) trace_config.on_request_chunk_sent.append(on_request_chunk_sent) trace_config.on_response_chunk_received.append(on_response_chunk_received) trace_config.on_request_headers_sent.append(on_request_headers_sent) session = await aiohttp_client(app, trace_configs=[trace_config]) def reset_mocks() -> None: for m in mocks: m.reset_mock() def to_trace_urls(mock_func: mock.Mock) -> list[URL]: return [call_args[0][-1].url for call_args in mock_func.call_args_list] def to_url(path: str) -> URL: return session.make_url(path) # Standard req: Callable[[], Awaitable[aiohttp.ClientResponse]] for req in ( lambda: session.get("/?x=0"), lambda: session.get("/", params=dict(x=0)), ): reset_mocks() async with req() as resp: await resp.text() assert to_trace_urls(on_request_start) == [to_url("/?x=0")] assert to_trace_urls(on_request_redirect) == [] assert to_trace_urls(on_request_end) == [to_url("/?x=0")] assert to_trace_urls(on_request_exception) == [] assert to_trace_urls(on_request_chunk_sent) == [] assert 
async def test_request_tracing_exception() -> None:
    """Check the trace signals sent when connecting raises.

    ``TCPConnector.connect`` is patched to raise, so the request must emit
    ``on_request_exception`` exactly once with that very exception and must
    never emit ``on_request_end``.
    """
    on_request_end = mock.AsyncMock()
    on_request_exception = mock.AsyncMock()

    trace_config = aiohttp.TraceConfig()
    trace_config.on_request_end.append(on_request_end)
    trace_config.on_request_exception.append(on_request_exception)

    with mock.patch("aiohttp.client.TCPConnector.connect") as connect_patched:
        error = Exception()
        connect_patched.side_effect = error

        # ``async with`` guarantees the session is closed even when one of
        # the assertions below fails.
        async with aiohttp.ClientSession(trace_configs=[trace_config]) as session:
            # The request is expected to fail; only the emitted signals matter.
            with contextlib.suppress(Exception):
                await session.get("http://example.com")

            on_request_exception.assert_called_once_with(
                session,
                mock.ANY,
                aiohttp.TraceRequestExceptionParams(
                    hdrs.METH_GET, URL("http://example.com"), CIMultiDict(), error
                ),
            )
            assert not on_request_end.called
async def test_client_session_timeout_bad_argument() -> None:
    """A ``timeout`` that is a plain string or a bare number raises ValueError."""
    invalid_timeouts: tuple[object, ...] = ("test_bad_argumnet", 100)
    for invalid in invalid_timeouts:
        with pytest.raises(ValueError):
            ClientSession(timeout=invalid)  # type: ignore[arg-type]
"test2?q=foo#bar", URL("http://example.com/test1/test2?q=foo#bar"), id="base_url=URL('http://example.com/test1/') url='test2?q=foo#bar'", ), pytest.param( URL("http://example.com/test1/"), "http://foo.com/bar", URL("http://foo.com/bar"), id="base_url=URL('http://example.com/test1/') url='http://foo.com/bar'", ), pytest.param( URL("http://example.com"), "http://foo.com/bar", URL("http://foo.com/bar"), id="base_url=URL('http://example.com') url='http://foo.com/bar'", ), pytest.param( URL("http://example.com/test1/"), "http://foo.com", URL("http://foo.com"), id="base_url=URL('http://example.com/test1/') url='http://foo.com'", ), ], ) async def test_build_url_returns_expected_url( # type: ignore[misc] create_session: Callable[..., Awaitable[ClientSession]], base_url: URL | str | None, url: URL | str, expected_url: URL, ) -> None: session = await create_session(base_url) assert session._build_url(url) == expected_url async def test_base_url_without_trailing_slash() -> None: with pytest.raises(ValueError, match="base_url must have a trailing '/'"): ClientSession(base_url="http://example.com/test") async def test_instantiation_with_invalid_timeout_value( loop: asyncio.AbstractEventLoop, ) -> None: loop.set_debug(False) logs = [] loop.set_exception_handler(lambda loop, ctx: logs.append(ctx)) with pytest.raises(ValueError, match="timeout parameter cannot be .*"): ClientSession(timeout=1) # type: ignore[arg-type] # should not have "Unclosed client session" warning assert not logs @pytest.mark.parametrize( ("outer_name", "inner_name"), [ ("skip_auto_headers", "_skip_auto_headers"), ("auth", "_default_auth"), ("json_serialize", "_json_serialize"), ("connector_owner", "_connector_owner"), ("raise_for_status", "_raise_for_status"), ("trust_env", "_trust_env"), ("trace_configs", "_trace_configs"), ], ) async def test_properties( session: ClientSession, outer_name: str, inner_name: str ) -> None: value = uuid4() setattr(session, inner_name, value) assert value == getattr(session, 
@pytest.mark.usefixtures("netrc_default_contents")
async def test_netrc_auth_with_trust_env(auth_server: TestServer) -> None:
    """With ``trust_env=True`` the request carries Basic credentials.

    The expected header value is the Basic encoding of
    ``netrc_user:netrc_pass``.
    """
    async with ClientSession(trust_env=True) as session:
        async with session.get(auth_server.make_url("/")) as resp:
            text = await resp.text()
            # "bmV0cmNfdXNlcjpuZXRyY19wYXNz" is base64 for "netrc_user:netrc_pass".
            assert text == "auth:Basic bmV0cmNfdXNlcjpuZXRyY19wYXNz"
async def test_ws_connect(
    ws_key: str, loop: asyncio.AbstractEventLoop, key_data: bytes
) -> None:
    """A valid 101 handshake yields a websocket response with the agreed protocol.

    ``ClientSession.request`` is patched to return a prepared mock response,
    and ``os.urandom`` is patched so the handshake key matches ``ws_key``.
    No ``Origin`` header may be sent when no origin was requested.
    """
    resp = mock.Mock()
    resp.status = 101
    resp.headers = {
        hdrs.UPGRADE: "websocket",
        hdrs.CONNECTION: "upgrade",
        hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
        hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
    }
    resp.connection.protocol.read_timeout = None
    with (
        mock.patch("aiohttp.client.os") as m_os,
        mock.patch("aiohttp.client.ClientSession.request") as m_req,
    ):
        m_os.urandom.return_value = key_data
        m_req.return_value = loop.create_future()
        m_req.return_value.set_result(resp)

        # Use the session as a context manager so it is always closed;
        # previously it was created inline and never released.
        async with aiohttp.ClientSession() as session:
            res = await session.ws_connect(
                "http://test.org", protocols=("t1", "t2", "chat")
            )

    assert isinstance(res, client.ClientWebSocketResponse)
    assert res.protocol == "chat"
    assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
async def test_ws_connect_read_timeout_reset_to_max(
    ws_key: str, loop: asyncio.AbstractEventLoop, key_data: bytes
) -> None:
    """The protocol's read timeout ends up at the larger websocket timeout.

    The mocked connection starts with ``read_timeout = 0.5`` and the
    websocket is opened with ``ClientWSTimeout(1.0)``; afterwards the
    protocol's ``read_timeout`` must be ``1.0``.
    """
    resp = mock.Mock()
    resp.status = 101
    resp.headers = {
        hdrs.UPGRADE: "websocket",
        hdrs.CONNECTION: "upgrade",
        hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
        hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
    }
    resp.connection.protocol.read_timeout = 0.5
    with (
        mock.patch("aiohttp.client.os") as m_os,
        mock.patch("aiohttp.client.ClientSession.request") as m_req,
    ):
        m_os.urandom.return_value = key_data
        m_req.return_value = loop.create_future()
        m_req.return_value.set_result(resp)

        # Use the session as a context manager so it is always closed;
        # previously it was created inline and never released.
        async with aiohttp.ClientSession() as session:
            res = await session.ws_connect(
                "http://test.org",
                protocols=("t1", "t2", "chat"),
                timeout=ClientWSTimeout(1.0),
            )

    assert isinstance(res, client.ClientWebSocketResponse)
    assert res.protocol == "chat"
    assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
    assert resp.connection.protocol.read_timeout == 1.0
async def test_ws_connect_custom_response(
    loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes
) -> None:
    """``ws_response_class`` controls the class of the object ``ws_connect`` returns.

    The custom subclass adds a ``read`` method; being able to call it on the
    result proves the subclass was instantiated.
    """

    class CustomResponse(client.ClientWebSocketResponse):
        def read(self, decode: bool = False) -> str:
            return "customized!"

    resp = mock.Mock()
    resp.status = 101
    resp.headers = {
        hdrs.UPGRADE: "websocket",
        hdrs.CONNECTION: "upgrade",
        hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
    }
    resp.connection.protocol.read_timeout = None
    with (
        mock.patch("aiohttp.client.os") as m_os,
        mock.patch("aiohttp.client.ClientSession.request") as m_req,
    ):
        m_os.urandom.return_value = key_data
        m_req.return_value = loop.create_future()
        m_req.return_value.set_result(resp)

        # Use the session as a context manager so it is always closed;
        # previously it was created inline and never released.
        async with aiohttp.ClientSession(
            ws_response_class=CustomResponse
        ) as session:
            res = await session.ws_connect("http://test.org")

    # TODO(PY313): Use TypeVar(default=) to make ClientSession Generic over response classes.
    # Then .ws_connect() can return CustomResponse here.
    assert res.read() == "customized!"  # type: ignore[attr-defined]
"http://test.org", protocols=("t1", "t2", "chat") ) assert ctx.value.message == "Invalid connection header" async def test_ws_connect_err_challenge( loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes ) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: "asdfasdfasdfasdfasdfasdf", } with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) with pytest.raises(client.WSServerHandshakeError) as ctx: await aiohttp.ClientSession().ws_connect( "http://test.org", protocols=("t1", "t2", "chat") ) assert ctx.value.message == "Invalid challenge response" async def test_ws_connect_common_headers( ws_key: str, loop: asyncio.AbstractEventLoop, key_data: bytes ) -> None: # Emulate a headers dict being reused for a second ws_connect. # In this scenario, we need to ensure that the newly generated secret key # is sent to the server, not the stale key. 
async def test_close(
    loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes
) -> None:
    """Closing a websocket sends one close frame and is idempotent.

    The first ``close()`` returns a truthy value and calls the writer's
    ``close(1000, b"")``; a second ``close()`` returns a falsy value and
    does not call the writer again.
    """
    mresp = mock.Mock()
    mresp.status = 101
    mresp.headers = {
        hdrs.UPGRADE: "websocket",
        hdrs.CONNECTION: "upgrade",
        hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
    }
    mresp.connection.protocol.read_timeout = None
    with (
        mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter,
        mock.patch("aiohttp.client.os") as m_os,
        mock.patch("aiohttp.client.ClientSession.request") as m_req,
    ):
        m_os.urandom.return_value = key_data
        handshake = loop.create_future()
        handshake.set_result(mresp)
        m_req.return_value = handshake

        # Strictly spec'd writer so only real WebSocketWriter methods exist.
        writer = mock.create_autospec(
            RealWebSocketWriter, instance=True, spec_set=True
        )
        WebSocketWriter.return_value = writer

        session = aiohttp.ClientSession()
        resp = await session.ws_connect("http://test.org")
        assert not resp.closed

        # Queue a close message so that close() has something to receive.
        resp._reader.feed_data(WSMessageClose(data=0, size=0, extra=""))

        res = await resp.close()
        writer.close.assert_called_with(1000, b"")
        assert resp.closed
        assert res  # type: ignore[unreachable]
        assert resp.exception() is None

        # A second close is a no-op.
        res = await resp.close()
        assert not res
        assert writer.close.call_count == 1

        await session.close()
ServerDisconnectedError() resp._reader.set_exception(exc) msg = await resp.receive() assert msg.type is aiohttp.WSMsgType.CLOSED assert resp.closed await session.close() # type: ignore[unreachable] async def test_close_exc( loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes ) -> None: mresp = mock.Mock() mresp.status = 101 mresp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } mresp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(mresp) writer = mock.create_autospec( RealWebSocketWriter, instance=True, spec_set=True ) WebSocketWriter.return_value = writer session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") assert not resp.closed exc = ValueError() resp._reader.set_exception(exc) await resp.close() assert resp.closed assert resp.exception() is exc # type: ignore[unreachable] await session.close() async def test_close_exc2( loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes ) -> None: mresp = mock.Mock() mresp.status = 101 mresp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } mresp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(mresp) writer = WebSocketWriter.return_value = mock.Mock() resp = await aiohttp.ClientSession().ws_connect("http://test.org") assert not resp.closed exc = ValueError() writer.close.side_effect = exc await resp.close() assert 
resp.closed assert resp.exception() is exc # type: ignore[unreachable] resp._closed = False writer.close.side_effect = asyncio.CancelledError() with pytest.raises(asyncio.CancelledError): await resp.close() @pytest.mark.parametrize("exc", (ClientConnectionResetError, ConnectionResetError)) async def test_send_data_after_close( exc: type[Exception], ws_key: str, key_data: bytes, loop: asyncio.AbstractEventLoop, ) -> None: mresp = mock.Mock() mresp.status = 101 mresp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } mresp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(mresp) resp = await aiohttp.ClientSession().ws_connect("http://test.org") resp._writer._closing = True for meth, args in ( (resp.ping, ()), (resp.pong, ()), (resp.send_str, ("s",)), (resp.send_bytes, (b"b",)), (resp.send_json, ({},)), (resp.send_frame, (b"", aiohttp.WSMsgType.BINARY)), ): with pytest.raises(exc): # Verify exc can be caught with both classes await meth(*args) async def test_send_data_type_errors( ws_key: str, key_data: bytes, loop: asyncio.AbstractEventLoop ) -> None: mresp = mock.Mock() mresp.status = 101 mresp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } mresp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(mresp) WebSocketWriter.return_value = mock.Mock() resp = await aiohttp.ClientSession().ws_connect("http://test.org") with pytest.raises(TypeError): await resp.send_str(b"s") # type: 
ignore[arg-type] with pytest.raises(TypeError): await resp.send_bytes("b") # type: ignore[arg-type] with pytest.raises(TypeError): await resp.send_json(set()) async def test_reader_read_exception( ws_key: str, key_data: bytes, loop: asyncio.AbstractEventLoop ) -> None: hresp = mock.Mock() hresp.status = 101 hresp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } hresp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(hresp) writer = mock.create_autospec( RealWebSocketWriter, instance=True, spec_set=True ) WebSocketWriter.return_value = writer session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") exc = ValueError() resp._reader.set_exception(exc) msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.ERROR assert resp.exception() is exc await session.close() async def test_receive_runtime_err(loop: asyncio.AbstractEventLoop) -> None: resp = client.ClientWebSocketResponse( mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock(), ClientWSTimeout(ws_receive=10.0), True, True, loop, ) resp._waiting = True with pytest.raises(RuntimeError): await resp.receive() async def test_heartbeat_reset_coalesces_on_data( loop: asyncio.AbstractEventLoop, ) -> None: response = mock.Mock() response.connection = None resp = client.ClientWebSocketResponse( mock.Mock(), mock.Mock(), None, response, ClientWSTimeout(ws_receive=10.0), True, True, loop, heartbeat=0.05, ) with mock.patch.object(resp, "_reset_heartbeat", autospec=True) as reset: resp._on_data_received() resp._on_data_received() await asyncio.sleep(0) assert reset.call_count == 1 async def test_receive_does_not_reset_heartbeat( loop: asyncio.AbstractEventLoop, ) -> 
None: response = mock.Mock() response.connection = None msg = mock.Mock(type=aiohttp.WSMsgType.TEXT) reader = mock.Mock() reader.read = mock.AsyncMock(return_value=msg) resp = client.ClientWebSocketResponse( reader, mock.Mock(), None, response, ClientWSTimeout(ws_receive=10.0), True, True, loop, heartbeat=0.05, ) with mock.patch.object(resp, "_reset_heartbeat", autospec=True) as reset: received = await resp.receive() assert received is msg reset.assert_not_called() async def test_cancel_heartbeat_cancels_pending_heartbeat_reset_handle( loop: asyncio.AbstractEventLoop, ) -> None: response = mock.Mock() response.connection = None resp = client.ClientWebSocketResponse( mock.Mock(), mock.Mock(), None, response, ClientWSTimeout(ws_receive=10.0), True, True, loop, heartbeat=0.05, ) resp._on_data_received() handle = resp._heartbeat_reset_handle assert handle is not None resp._cancel_heartbeat() assert resp._heartbeat_reset_handle is None assert resp._need_heartbeat_reset is False assert handle.cancelled() async def test_flush_heartbeat_reset_returns_early_when_not_needed( loop: asyncio.AbstractEventLoop, ) -> None: response = mock.Mock() response.connection = None resp = client.ClientWebSocketResponse( mock.Mock(), mock.Mock(), None, response, ClientWSTimeout(ws_receive=10.0), True, True, loop, heartbeat=0.05, ) resp._need_heartbeat_reset = False with mock.patch.object(resp, "_reset_heartbeat", autospec=True) as reset: resp._flush_heartbeat_reset() reset.assert_not_called() async def test_send_heartbeat_returns_early_when_reset_is_pending( loop: asyncio.AbstractEventLoop, ) -> None: response = mock.Mock() response.connection = None writer = mock.Mock() resp = client.ClientWebSocketResponse( mock.Mock(), writer, None, response, ClientWSTimeout(ws_receive=10.0), True, True, loop, heartbeat=0.05, ) resp._need_heartbeat_reset = True resp._send_heartbeat() writer.send_frame.assert_not_called() async def test_ws_connect_close_resp_on_err( loop: asyncio.AbstractEventLoop, 
ws_key: str, key_data: bytes ) -> None: resp = mock.Mock() resp.status = 500 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) with pytest.raises(client.WSServerHandshakeError): await aiohttp.ClientSession().ws_connect( "http://test.org", protocols=("t1", "t2", "chat") ) resp.close.assert_called_with() async def test_ws_connect_non_overlapped_protocols( ws_key: str, loop: asyncio.AbstractEventLoop, key_data: bytes ) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another", } resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) res = await aiohttp.ClientSession().ws_connect( "http://test.org", protocols=("t1", "t2", "chat") ) assert res.protocol is None async def test_ws_connect_non_overlapped_protocols_2( ws_key: str, loop: asyncio.AbstractEventLoop, key_data: bytes ) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another", } resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) connector = aiohttp.TCPConnector(force_close=True) res = await 
aiohttp.ClientSession(connector=connector).ws_connect( "http://test.org", protocols=("t1", "t2", "chat") ) assert res.protocol is None del res async def test_ws_connect_deflate( loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes ) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate", } resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) res = await aiohttp.ClientSession().ws_connect( "http://test.org", compress=15 ) assert res.compress == 15 assert res.client_notakeover is False async def test_ws_connect_deflate_per_message( loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes ) -> None: mresp = mock.Mock() mresp.status = 101 mresp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate", } mresp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter: with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(mresp) writer = mock.create_autospec( RealWebSocketWriter, instance=True, spec_set=True ) WebSocketWriter.return_value = writer session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") await resp.send_str("string", compress=-1) writer.send_frame.assert_called_with( b"string", aiohttp.WSMsgType.TEXT, compress=-1 ) await resp.send_bytes(b"bytes", compress=15) writer.send_frame.assert_called_with( b"bytes", aiohttp.WSMsgType.BINARY, compress=15 
) await resp.send_json([{}], compress=-9) writer.send_frame.assert_called_with( b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9 ) await resp.send_frame(b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9) writer.send_frame.assert_called_with( b"[{}]", aiohttp.WSMsgType.TEXT, -9 ) await session.close() async def test_ws_connect_deflate_server_not_support( loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes ) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) res = await aiohttp.ClientSession().ws_connect( "http://test.org", compress=15 ) assert res.compress == 0 assert res.client_notakeover is False async def test_ws_connect_deflate_notakeover( loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes ) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " "client_no_context_takeover", } resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) res = await aiohttp.ClientSession().ws_connect( "http://test.org", compress=15 ) assert res.compress == 15 assert res.client_notakeover is True async def test_ws_connect_deflate_client_wbits( loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes ) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", 
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " "client_max_window_bits=10", } resp.connection.protocol.read_timeout = None with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) res = await aiohttp.ClientSession().ws_connect( "http://test.org", compress=15 ) assert res.compress == 10 assert res.client_notakeover is False async def test_ws_connect_deflate_client_wbits_bad( loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes ) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; " "client_max_window_bits=6", } with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) with pytest.raises(client.WSServerHandshakeError): await aiohttp.ClientSession().ws_connect("http://test.org", compress=15) async def test_ws_connect_deflate_server_ext_bad( loop: asyncio.AbstractEventLoop, ws_key: str, key_data: bytes ) -> None: resp = mock.Mock() resp.status = 101 resp.headers = { hdrs.UPGRADE: "websocket", hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; bad", } with mock.patch("aiohttp.client.os") as m_os: with mock.patch("aiohttp.client.ClientSession.request") as m_req: m_os.urandom.return_value = key_data m_req.return_value = loop.create_future() m_req.return_value.set_result(resp) with pytest.raises(client.WSServerHandshakeError): await aiohttp.ClientSession().ws_connect("http://test.org", compress=15) ================================================ FILE: 
tests/test_client_ws_functional.py ================================================ import asyncio import json import struct import sys from contextlib import suppress from typing import Literal, NoReturn from unittest import mock import pytest import aiohttp from aiohttp import ( ClientConnectionResetError, ServerTimeoutError, WSMessageTypeError, WSMsgType, hdrs, web, ) from aiohttp._websocket.models import WSMessageBinary from aiohttp._websocket.reader import WebSocketDataQueue from aiohttp.client_ws import ClientWSTimeout from aiohttp.http import WSCloseCode from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer if sys.version_info >= (3, 11): import asyncio as async_timeout else: import async_timeout class PatchableWebSocketDataQueue(WebSocketDataQueue): """A WebSocketDataQueue that can be patched.""" async def test_send_recv_text(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() await ws.send_str(msg + "/answer") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_str("ask") assert resp.get_extra_info("socket") is not None data = await resp.receive_str() assert data == "ask/answer" await resp.close() assert resp.get_extra_info("socket") is None async def test_send_recv_bytes_bad_type(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() await ws.send_str(msg + "/answer") await ws.close() assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_str("ask") with pytest.raises(WSMessageTypeError): await resp.receive_bytes() await resp.close() async def 
test_recv_bytes_after_close(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) await ws.close() assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") with pytest.raises( WSMessageTypeError, match=f"Received message {WSMsgType.CLOSE}:.+ is not WSMsgType.BINARY", ): await resp.receive_bytes() await resp.close() async def test_send_recv_bytes(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_bytes() await ws.send_bytes(msg + b"/answer") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_bytes(b"ask") data = await resp.receive_bytes() assert data == b"ask/answer" await resp.close() async def test_send_recv_text_bad_type(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_bytes() await ws.send_bytes(msg + b"/answer") await ws.close() assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_bytes(b"ask") with pytest.raises(WSMessageTypeError): await resp.receive_str() await resp.close() async def test_recv_text_after_close(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) await ws.close() assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") with pytest.raises( WSMessageTypeError, match=f"Received message 
{WSMsgType.CLOSE}:.+ is not WSMsgType.TEXT", ): await resp.receive_str() await resp.close() async def test_send_recv_json(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) data = await ws.receive_json() await ws.send_json({"response": data["request"]}) await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") payload = {"request": "test"} await resp.send_json(payload) data = await resp.receive_json() assert data["response"] == payload["request"] await resp.close() async def test_send_recv_json_bytes(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.send_bytes(json.dumps({"response": "x"}).encode()) await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") data = await resp.receive() assert isinstance(data, WSMessageBinary) assert data.json() == {"response": "x"} await resp.close() async def test_send_json_bytes_client(aiohttp_client: AiohttpClient) -> None: """Test ClientWebSocketResponse.send_json_bytes sends binary frame.""" async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive() assert msg.type is WSMsgType.BINARY data = json.loads(msg.data) await ws.send_json_bytes( {"response": data["request"]}, dumps=lambda x: json.dumps(x).encode("utf-8"), ) await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") test_payload = {"request": "test"} await resp.send_json_bytes( test_payload, dumps=lambda x: json.dumps(x).encode("utf-8") ) msg = 
await resp.receive() assert msg.type is WSMsgType.BINARY data = json.loads(msg.data) assert data["response"] == test_payload["request"] await resp.close() async def test_send_json_bytes_custom_encoder(aiohttp_client: AiohttpClient) -> None: """Test send_json_bytes with custom bytes-returning encoder.""" async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive() assert msg.type is WSMsgType.BINARY # Custom encoder uses compact separators assert msg.data == b'{"test":"value"}' await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_json_bytes( {"test": "value"}, dumps=lambda x: json.dumps(x, separators=(",", ":")).encode("utf-8"), ) await resp.close() async def test_send_recv_frame(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive() assert msg.type is WSMsgType.BINARY await ws.send_frame(msg.data, msg.type) await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_frame(b"test", WSMsgType.BINARY) data = await resp.receive() assert data.data == b"test" assert data.type is WSMsgType.BINARY await resp.close() async def test_ping_pong(aiohttp_client: AiohttpClient) -> None: loop = asyncio.get_event_loop() closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_bytes() await ws.ping() await ws.send_bytes(msg + b"/answer") try: await ws.close() finally: closed.set_result(1) return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await 
client.ws_connect("/") await resp.ping() await resp.send_bytes(b"ask") msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.BINARY assert msg.data == b"ask/answer" msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSE await resp.close() await closed async def test_ping_pong_manual(aiohttp_client: AiohttpClient) -> None: loop = asyncio.get_event_loop() closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_bytes() await ws.ping() await ws.send_bytes(msg + b"/answer") try: await ws.close() finally: closed.set_result(1) return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", autoping=False) await resp.ping() await resp.send_bytes(b"ask") msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.PONG msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.PING await resp.pong() msg = await resp.receive() assert msg.data == b"ask/answer" msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSE await closed async def test_close(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive_bytes() await ws.send_str("test") await ws.receive() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_bytes(b"ask") closed = await resp.close() assert closed assert resp.closed assert resp.close_code == 1000 msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSED async def test_concurrent_task_close(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive() return ws 
app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) async with client.ws_connect("/") as resp: # wait for the message in a separate task task = asyncio.create_task(resp.receive()) # Make sure we start to wait on receiving message before closing the connection await asyncio.sleep(0.1) closed = await resp.close() await task assert closed assert resp.closed assert resp.close_code == 1000 async def test_concurrent_close(aiohttp_client: AiohttpClient) -> None: client_ws: aiohttp.ClientWebSocketResponse | None = None async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive_bytes() await ws.send_str("test") assert client_ws is not None await client_ws.close() msg = await ws.receive() assert msg.type is aiohttp.WSMsgType.CLOSE return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) ws = client_ws = await client.ws_connect("/") await ws.send_bytes(b"ask") msg = await ws.receive() assert msg.type is aiohttp.WSMsgType.CLOSING await asyncio.sleep(0.01) msg = await ws.receive() assert msg.type is aiohttp.WSMsgType.CLOSED async def test_concurrent_close_multiple_tasks(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive_bytes() await ws.send_str("test") msg = await ws.receive() assert msg.type is aiohttp.WSMsgType.CLOSE return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) ws = await client.ws_connect("/") await ws.send_bytes(b"ask") task1 = asyncio.create_task(ws.close()) task2 = asyncio.create_task(ws.close()) msg = await ws.receive() assert msg.type is aiohttp.WSMsgType.CLOSED await task1 await task2 msg = await ws.receive() assert msg.type is aiohttp.WSMsgType.CLOSED async def test_close_from_server(aiohttp_client: 
AiohttpClient) -> None: loop = asyncio.get_event_loop() closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) try: await ws.receive_bytes() await ws.close() finally: closed.set_result(1) return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_bytes(b"ask") msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSE assert resp.closed msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSED await closed async def test_close_manual(aiohttp_client: AiohttpClient) -> None: loop = asyncio.get_event_loop() closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive_bytes() await ws.send_str("test") try: await ws.close() finally: closed.set_result(1) return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", autoclose=False) await resp.send_bytes(b"ask") msg = await resp.receive() assert msg.data == "test" msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSE assert msg.data == 1000 assert msg.extra == "" assert not resp.closed await resp.close() await closed assert resp.closed async def test_close_timeout_sock_close_read(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive_bytes() await ws.send_str("test") await asyncio.sleep(1) assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) timeout = ClientWSTimeout(ws_close=0.2) resp = await client.ws_connect("/", timeout=timeout, autoclose=False) await resp.send_bytes(b"ask") msg = await resp.receive() assert msg.data == "test" 
assert msg.type == aiohttp.WSMsgType.TEXT await resp.close() assert resp.closed assert isinstance(resp.exception(), asyncio.TimeoutError) async def test_close_timeout_deprecated(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive_bytes() await ws.send_str("test") await asyncio.sleep(1) assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) with pytest.warns( DeprecationWarning, match="parameter 'timeout' of type 'float' " "is deprecated, please use " r"'timeout=ClientWSTimeout\(ws_close=...\)'", ): resp = await client.ws_connect("/", timeout=0.2, autoclose=False) await resp.send_bytes(b"ask") msg = await resp.receive() assert msg.data == "test" assert msg.type == aiohttp.WSMsgType.TEXT await resp.close() assert resp.closed assert isinstance(resp.exception(), asyncio.TimeoutError) async def test_close_cancel(aiohttp_client: AiohttpClient) -> None: loop = asyncio.get_event_loop() async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive_bytes() await ws.send_str("test") await asyncio.sleep(10) assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", autoclose=False) await resp.send_bytes(b"ask") text = await resp.receive() assert text.data == "test" t = loop.create_task(resp.close()) await asyncio.sleep(0.1) t.cancel() await asyncio.sleep(0.1) assert resp.closed assert resp.exception() is None async def test_override_default_headers(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: assert request.headers[hdrs.SEC_WEBSOCKET_VERSION] == "8" ws = web.WebSocketResponse() await ws.prepare(request) await ws.send_str("answer") await ws.close() return ws app = web.Application() 
app.router.add_route("GET", "/", handler) headers = {hdrs.SEC_WEBSOCKET_VERSION: "8"} client = await aiohttp_client(app) resp = await client.ws_connect("/", headers=headers) msg = await resp.receive() assert msg.data == "answer" await resp.close() async def test_additional_headers(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: assert request.headers["x-hdr"] == "xtra" ws = web.WebSocketResponse() await ws.prepare(request) await ws.send_str("answer") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", headers={"x-hdr": "xtra"}) msg = await resp.receive() assert msg.data == "answer" await resp.close() async def test_recv_protocol_error(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive_str() assert ws._writer is not None ws._writer.transport.write(b"01234" * 100) await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_str("ask") msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.ERROR assert type(msg.data) is aiohttp.WebSocketError assert msg.data.code == aiohttp.WSCloseCode.PROTOCOL_ERROR assert str(msg.data) == "Received frame with non-zero reserved bits" assert msg.extra is None await resp.close() async def test_recv_timeout(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive_str() await asyncio.sleep(0.1) assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_str("ask") with pytest.raises(asyncio.TimeoutError): async 
with async_timeout.timeout(0.01): await resp.receive() await resp.close() async def test_receive_timeout_sock_read(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive() await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) receive_timeout = ClientWSTimeout(ws_receive=0.1) resp = await client.ws_connect("/", timeout=receive_timeout) with pytest.raises(asyncio.TimeoutError): await resp.receive(timeout=0.05) await resp.close() async def test_receive_timeout_deprecation(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive() await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) with pytest.warns( DeprecationWarning, match="float parameter 'receive_timeout' " "is deprecated, please use parameter " r"'timeout=ClientWSTimeout\(ws_receive=...\)'", ): resp = await client.ws_connect("/", receive_timeout=0.1) with pytest.raises(asyncio.TimeoutError): await resp.receive(timeout=0.05) await resp.close() async def test_custom_receive_timeout(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive() await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") with pytest.raises(asyncio.TimeoutError): await resp.receive(0.05) await resp.close() async def test_heartbeat(aiohttp_client: AiohttpClient) -> None: ping_received = False async def handler(request: web.Request) -> NoReturn: nonlocal ping_received ws = web.WebSocketResponse(autoping=False) await 
ws.prepare(request) msg = await ws.receive() assert msg.type == aiohttp.WSMsgType.PING ping_received = True await ws.close() assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", heartbeat=0.01) await asyncio.sleep(0.1) await resp.receive() await resp.close() assert ping_received async def test_heartbeat_connection_closed(aiohttp_client: AiohttpClient) -> None: """Test that the connection is closed while ping is in progress.""" async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) await ws.receive() assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", heartbeat=0.1) ping_count = 0 # We patch write here to simulate a connection reset error # since if we closed the connection normally, the client # would cancel the heartbeat task and we wouldn't get a ping assert resp._conn is not None with ( mock.patch.object( resp._conn.transport, "write", side_effect=ClientConnectionResetError ), mock.patch.object( resp._writer, "send_frame", wraps=resp._writer.send_frame ) as send_frame, ): await resp.receive() ping_count = send_frame.call_args_list.count(mock.call(b"", WSMsgType.PING)) # Connection should be closed roughly after 1.5x heartbeat.
await asyncio.sleep(0.2) assert ping_count == 1 assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE async def test_heartbeat_no_pong(aiohttp_client: AiohttpClient) -> None: """Test that the connection is closed if no pong is received without sending messages.""" ping_received = False async def handler(request: web.Request) -> NoReturn: nonlocal ping_received ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) msg = await ws.receive() assert msg.type == aiohttp.WSMsgType.PING ping_received = True await ws.receive() assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", heartbeat=0.1) # Connection should be closed roughly after 1.5x heartbeat. await asyncio.sleep(0.2) assert ping_received assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE async def test_heartbeat_does_not_timeout_while_receiving_large_frame( aiohttp_client: AiohttpClient, ) -> None: """Slowly receiving a single large frame should not trip heartbeat. Regression test for the behavior described in https://github.com/aio-libs/aiohttp/discussions/12023: on slow connections, the websocket heartbeat used to be reset only after a full message was read, which could cause a ping/pong timeout while bytes were still being received. """ payload = b"x" * 2048 heartbeat = 0.1 chunk_size = 64 delay = 0.01 async def handler(request: web.Request) -> web.WebSocketResponse: # Disable auto-PONG so a heartbeat PING during frame streaming would # surface as a timeout/closure on the client side. ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) assert ws._writer is not None transport = ws._writer.transport # Server-to-client frames are not masked. 
length = len(payload) # payload is fixed length of 2048 bytes header = bytes((0x82, 126)) + struct.pack("!H", length) frame = header + payload for i in range(0, len(frame), chunk_size): transport.write(frame[i : i + chunk_size]) await asyncio.sleep(delay) # Ensure the server side is cleaned up. with suppress(asyncio.TimeoutError): await ws.receive(timeout=1.0) with suppress(Exception): await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) async with client.ws_connect("/", heartbeat=heartbeat) as resp: # If heartbeat were not reset on incoming bytes, the client would send # a PING while this frame is still being streamed. with mock.patch.object( resp._writer, "send_frame", wraps=resp._writer.send_frame ) as sf: msg = await resp.receive() assert ( sf.call_args_list.count(mock.call(b"", WSMsgType.PING)) == 0 ), "Heartbeat PING sent while data was still being received" assert msg.type is WSMsgType.BINARY assert msg.data == payload assert not resp.closed async def test_heartbeat_no_pong_after_receive_many_messages( aiohttp_client: AiohttpClient, ) -> None: """Test that the connection is closed if no pong is received after receiving many messages.""" ping_received = False async def handler(request: web.Request) -> NoReturn: nonlocal ping_received ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) for _ in range(5): await ws.send_str("test") await asyncio.sleep(0.05) for _ in range(5): await ws.send_str("test") msg = await ws.receive() ping_received = msg.type is aiohttp.WSMsgType.PING await ws.receive() assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", heartbeat=0.1) for _ in range(10): test_msg = await resp.receive() assert test_msg.data == "test" # Connection should be closed roughly after 1.5x heartbeat. 
await asyncio.sleep(0.2) assert ping_received assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE async def test_heartbeat_no_pong_after_send_many_messages( aiohttp_client: AiohttpClient, ) -> None: """Test that the connection is closed if no pong is received after sending many messages.""" ping_received = False async def handler(request: web.Request) -> NoReturn: nonlocal ping_received ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) for _ in range(10): msg = await ws.receive() assert msg.data == "test" assert msg.type is aiohttp.WSMsgType.TEXT msg = await ws.receive() ping_received = msg.type is aiohttp.WSMsgType.PING await ws.receive() assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", heartbeat=0.1) for _ in range(5): await resp.send_str("test") await asyncio.sleep(0.05) for _ in range(5): await resp.send_str("test") # Connection should be closed roughly after 1.5x heartbeat. 
await asyncio.sleep(0.2) assert ping_received assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE async def test_heartbeat_no_pong_concurrent_receive( aiohttp_client: AiohttpClient, ) -> None: ping_received = False async def handler(request: web.Request) -> NoReturn: nonlocal ping_received ws = web.WebSocketResponse(autoping=False) with mock.patch( "aiohttp.web_ws.WebSocketDataQueue", PatchableWebSocketDataQueue ): await ws.prepare(request) msg = await ws.receive() ping_received = msg.type is aiohttp.WSMsgType.PING with mock.patch.object( ws._reader, "feed_eof", autospec=True, spec_set=True, return_value=None ): await asyncio.sleep(10.0) assert False app = web.Application() app.router.add_route("GET", "/", handler) with mock.patch("aiohttp.client.WebSocketDataQueue", PatchableWebSocketDataQueue): client = await aiohttp_client(app) resp = await client.ws_connect("/", heartbeat=0.1) with mock.patch.object( resp._reader, "feed_eof", autospec=True, spec_set=True, return_value=None ): # Connection should be closed roughly after 1.5x heartbeat. 
msg = await resp.receive(5.0) assert ping_received assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE assert msg.type is WSMsgType.ERROR assert isinstance(msg.data, ServerTimeoutError) assert str(msg.data) == "No PONG received after 0.05 seconds" async def test_close_websocket_while_ping_inflight( aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop ) -> None: """Test closing the websocket while a ping is in-flight.""" ping_received = False async def handler(request: web.Request) -> NoReturn: nonlocal ping_received ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) msg = await ws.receive() assert msg.type is aiohttp.WSMsgType.BINARY msg = await ws.receive() ping_received = msg.type is aiohttp.WSMsgType.PING await ws.receive() assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", heartbeat=0.1) await resp.send_bytes(b"ask") cancelled = False ping_started = loop.create_future() async def delayed_send_frame( message: bytes, opcode: int, compress: int | None = None ) -> None: assert opcode == WSMsgType.PING nonlocal cancelled ping_started.set_result(None) try: await asyncio.sleep(1) except asyncio.CancelledError: cancelled = True raise with mock.patch.object(resp._writer, "send_frame", delayed_send_frame): async with async_timeout.timeout(1): await ping_started await resp.close() await asyncio.sleep(0) assert ping_started.result() is None assert cancelled is True @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_send_recv_compress(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() await ws.send_str(msg + "/answer") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", compress=15) 
await resp.send_str("ask") assert resp.compress == 15 data = await resp.receive_str() assert data == "ask/answer" await resp.close() assert resp.get_extra_info("socket") is None @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_send_recv_compress_wbits(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() await ws.send_str(msg + "/answer") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/", compress=9) await resp.send_str("ask") # Client indicates it supports wbits 15 # Server supports wbits 15 for decode assert resp.compress == 15 data = await resp.receive_str() assert data == "ask/answer" await resp.close() assert resp.get_extra_info("socket") is None async def test_send_recv_compress_wbit_error(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) with pytest.raises(ValueError): await client.ws_connect("/", compress=1) async def test_ws_client_async_for(aiohttp_client: AiohttpClient) -> None: items = ["q1", "q2", "q3"] async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) for i in items: await ws.send_str(i) await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") it = iter(items) async for msg in resp: assert msg.data == next(it) with pytest.raises(StopIteration): next(it) assert resp.closed async def test_ws_async_with(aiohttp_server: AiohttpServer) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await
ws.prepare(request) msg = await ws.receive() assert msg.type is WSMsgType.TEXT await ws.send_str(msg.data + "/answer") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) async with aiohttp.ClientSession() as client: async with client.ws_connect(server.make_url("/")) as ws: await ws.send_str("request") msg = await ws.receive() assert msg.data == "request/answer" assert ws.closed async def test_ws_async_with_send(aiohttp_server: AiohttpServer) -> None: # send_xxx methods have to return awaitable objects async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive() assert msg.type is WSMsgType.TEXT await ws.send_str(msg.data + "/answer") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) async with aiohttp.ClientSession() as client: async with client.ws_connect(server.make_url("/")) as ws: await ws.send_str("request") msg = await ws.receive() assert msg.data == "request/answer" assert ws.closed async def test_ws_async_with_shortcut(aiohttp_server: AiohttpServer) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive() assert msg.type is WSMsgType.TEXT await ws.send_str(msg.data + "/answer") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) async with aiohttp.ClientSession() as client: async with client.ws_connect(server.make_url("/")) as ws: await ws.send_str("request") msg = await ws.receive() assert msg.data == "request/answer" assert ws.closed async def test_closed_async_for(aiohttp_client: AiohttpClient) -> None: loop = asyncio.get_event_loop() closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: ws = 
web.WebSocketResponse() await ws.prepare(request) try: await ws.send_bytes(b"started") await ws.receive_bytes() finally: closed.set_result(1) return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") messages = [] async for msg in resp: messages.append(msg) assert b"started" == msg.data await resp.send_bytes(b"ask") await resp.close() assert 1 == len(messages) assert messages[0].type == aiohttp.WSMsgType.BINARY assert messages[0].data == b"started" assert resp.closed await closed async def test_peer_connection_lost(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() assert msg == "ask" await ws.send_str("answer") assert request.transport is not None request.transport.close() await asyncio.sleep(10) assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_str("ask") assert "answer" == await resp.receive_str() msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSED await resp.close() async def test_peer_connection_lost_iter(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() assert msg == "ask" await ws.send_str("answer") assert request.transport is not None request.transport.close() await asyncio.sleep(100) assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_str("ask") async for msg in resp: assert "answer" == msg.data await resp.close() async def test_ws_connect_with_wrong_ssl_type(aiohttp_client: AiohttpClient) -> None: app = web.Application() session = await aiohttp_client(app) with 
pytest.raises(TypeError, match="ssl should be SSLContext, .*"): await session.ws_connect("/", ssl=42) async def test_websocket_connection_not_closed_properly( aiohttp_client: AiohttpClient, ) -> None: """Test that closing the connection via __del__ does not raise an exception.""" async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) await ws.close() assert False app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") assert resp._conn is not None # Simulate the connection not being closed properly # https://github.com/aio-libs/aiohttp/issues/9880 resp._conn.release() # Clean up so the test does not leak await resp.close() async def test_websocket_connection_cancellation( aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop ) -> None: """Test canceling the WebSocket connection task does not raise an exception in __del__.""" async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() await ws.prepare(request) await ws.close() assert False app = web.Application() app.router.add_route("GET", "/", handler) sync_future: asyncio.Future[list[aiohttp.ClientWebSocketResponse[bool]]] = ( loop.create_future() ) client = await aiohttp_client(app) async def websocket_task() -> None: resp = await client.ws_connect("/") assert resp is not None # ensure we hold a reference to the response # The test harness will clean up the unclosed websocket # for us, so we need to copy the websockets to ensure # we can control the cleanup sync_future.set_result(client._websockets.copy()) client._websockets.clear() await asyncio.sleep(0) task = loop.create_task(websocket_task()) websockets = await sync_future task.cancel() with pytest.raises(asyncio.CancelledError): await task websocket = websockets.pop() # Call the `__del__` methods manually since it is not reproducible when it gets gc'd del websocket._response # Cleanup properly
websocket._response = mock.Mock() await websocket.close() async def test_receive_text_as_bytes_client_side(aiohttp_client: AiohttpClient) -> None: """Test client receiving TEXT messages as raw bytes with decode_text=False.""" async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() await ws.send_str(msg + "/answer") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) # Connect with decode_text=False resp = await client.ws_connect("/", decode_text=False) await resp.send_str("ask") # Receive TEXT message as bytes msg = await resp.receive() assert msg.type is WSMsgType.TEXT assert isinstance(msg.data, bytes) assert msg.data == b"ask/answer" await resp.close() async def test_receive_text_as_bytes_server_side(aiohttp_client: AiohttpClient) -> None: """Test server receiving TEXT messages as raw bytes with decode_text=False.""" async def handler(request: web.Request) -> web.WebSocketResponse[Literal[False]]: ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse( decode_text=False ) await ws.prepare(request) # Receive TEXT message as bytes msg = await ws.receive() assert msg.type is WSMsgType.TEXT assert isinstance(msg.data, bytes) assert msg.data == b"test message" # Send response await ws.send_bytes(msg.data + b"/reply") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.ws_connect("/") await resp.send_str("test message") msg = await resp.receive() assert msg.type is WSMsgType.BINARY assert msg.data == b"test message/reply" await resp.close() async def test_receive_text_as_bytes_json_parsing( aiohttp_client: AiohttpClient, ) -> None: """Test using orjson or similar parsers with raw bytes from TEXT messages.""" async def handler(request: web.Request) -> web.WebSocketResponse: ws = 
web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() data = json.loads(msg) await ws.send_str(json.dumps({"response": data["value"] * 2})) await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) # Connect with decode_text=False to get raw bytes resp = await client.ws_connect("/", decode_text=False) await resp.send_str(json.dumps({"value": 42})) # Receive TEXT message as bytes msg = await resp.receive() assert msg.type is WSMsgType.TEXT assert isinstance(msg.data, bytes) # Parse JSON using msg.json() method (covers WSMessageTextBytes.json()) data = msg.json() assert data == {"response": 84} await resp.close() async def test_decode_text_default_true(aiohttp_client: AiohttpClient) -> None: """Test that decode_text defaults to True for backward compatibility.""" async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() await ws.send_str(msg + "/reply") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) # Default behavior (decode_text=True) resp = await client.ws_connect("/") await resp.send_str("test") # Should receive TEXT message as string msg = await resp.receive() assert msg.type is WSMsgType.TEXT assert isinstance(msg.data, str) assert msg.data == "test/reply" await resp.close() async def test_receive_str_returns_bytes_with_decode_text_false( aiohttp_client: AiohttpClient, ) -> None: """Test that receive_str() returns bytes when decode_text=False.""" async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.send_str("hello world") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) async with client.ws_connect("/", decode_text=False) as ws: # receive_str() 
should return bytes when decode_text=False data = await ws.receive_str() assert isinstance(data, bytes) assert data == b"hello world" async def test_receive_str_returns_str_with_decode_text_true( aiohttp_client: AiohttpClient, ) -> None: """Test that receive_str() returns str when decode_text=True (default).""" async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.send_str("hello world") await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) async with client.ws_connect("/") as ws: # receive_str() should return str when decode_text=True (default) data = await ws.receive_str() assert isinstance(data, str) assert data == "hello world" async def test_receive_json_with_orjson_style_loads( aiohttp_client: AiohttpClient, ) -> None: """Test receive_json() with orjson-style loads that accepts bytes.""" def orjson_style_loads(data: bytes) -> dict[str, int]: """Mock orjson.loads that accepts bytes.""" assert isinstance(data, bytes) result: dict[str, int] = json.loads(data) return result async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.send_str('{"value": 42}') await ws.close() return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) async with client.ws_connect("/", decode_text=False) as ws: # receive_json() with orjson-style loads should work with bytes data = await ws.receive_json(loads=orjson_style_loads) assert data == {"value": 42} ================================================ FILE: tests/test_compression_utils.py ================================================ """Tests for compression utils.""" import pytest from aiohttp.compression_utils import ZLibBackend, ZLibCompressor, ZLibDecompressor @pytest.mark.usefixtures("parametrize_zlib_backend") async def 
test_compression_round_trip_in_executor() -> None: """Ensure that compression and decompression work correctly in the executor.""" compressor = ZLibCompressor( strategy=ZLibBackend.Z_DEFAULT_STRATEGY, max_sync_chunk_size=1 ) assert type(compressor._compressor) is type(ZLibBackend.compressobj()) decompressor = ZLibDecompressor(max_sync_chunk_size=1) assert type(decompressor._decompressor) is type(ZLibBackend.decompressobj()) data = b"Hi" * 100 compressed_data = await compressor.compress(data) + compressor.flush() decompressed_data = await decompressor.decompress(compressed_data) assert data == decompressed_data @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_compression_round_trip_in_event_loop() -> None: """Ensure that compression and decompression work correctly in the event loop.""" compressor = ZLibCompressor( strategy=ZLibBackend.Z_DEFAULT_STRATEGY, max_sync_chunk_size=10000 ) assert type(compressor._compressor) is type(ZLibBackend.compressobj()) decompressor = ZLibDecompressor(max_sync_chunk_size=10000) assert type(decompressor._decompressor) is type(ZLibBackend.decompressobj()) data = b"Hi" * 100 compressed_data = await compressor.compress(data) + compressor.flush() decompressed_data = await decompressor.decompress(compressed_data) assert data == decompressed_data ================================================ FILE: tests/test_connector.py ================================================ # Tests of http client with custom Connector import asyncio import contextlib import gc import hashlib import platform import socket import ssl import sys import uuid import warnings from collections import defaultdict, deque from collections.abc import Awaitable, Callable, Iterator, Sequence from concurrent import futures from contextlib import closing, suppress from typing import Any, Literal, NoReturn from unittest import mock import pytest from multidict import CIMultiDict from pytest_mock import MockerFixture from yarl import URL import aiohttp from 
aiohttp import ( ClientRequest, ClientSession, ClientTimeout, connector as connector_module, hdrs, web, ) from aiohttp.abc import AbstractResolver, ResolveResult from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientRequestArgs, ConnectionKey from aiohttp.connector import ( _SSL_CONTEXT_UNVERIFIED, _SSL_CONTEXT_VERIFIED, AddrInfoType, Connection, TCPConnector, _ConnectTunnelConnection, _DNSCacheTable, ) from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.resolver import AsyncResolver from aiohttp.test_utils import unused_port from aiohttp.tracing import Trace if sys.version_info >= (3, 11): from typing import Unpack _RequestMaker = Callable[[str, URL, Unpack[ClientRequestArgs]], ClientRequest] else: _RequestMaker = Any @pytest.fixture def key() -> ConnectionKey: # Connection key return ConnectionKey("localhost", 80, False, True, None, None, None) @pytest.fixture def key2() -> ConnectionKey: # Connection key return ConnectionKey("localhost", 80, False, True, None, None, None) @pytest.fixture def other_host_key2() -> ConnectionKey: # Connection key return ConnectionKey("otherhost", 80, False, True, None, None, None) @pytest.fixture def ssl_key() -> ConnectionKey: # Connection key return ConnectionKey("localhost", 80, True, True, None, None, None) @pytest.fixture def unix_server( loop: asyncio.AbstractEventLoop, unix_sockname: str ) -> Iterator[Callable[[web.Application], Awaitable[None]]]: runners = [] async def go(app: web.Application) -> None: runner = web.AppRunner(app) runners.append(runner) await runner.setup() site = web.UnixSite(runner, unix_sockname) await site.start() yield go for runner in runners: loop.run_until_complete(runner.cleanup()) @pytest.fixture def named_pipe_server( proactor_loop: asyncio.AbstractEventLoop, pipe_name: str ) -> Iterator[Callable[[web.Application], Awaitable[None]]]: runners = [] async def go(app: web.Application) -> None: runner = web.AppRunner(app) runners.append(runner) 
await runner.setup() site = web.NamedPipeSite(runner, pipe_name) await site.start() yield go for runner in runners: proactor_loop.run_until_complete(runner.cleanup()) def create_mocked_conn( conn_closing_result: asyncio.AbstractEventLoop | None = None, should_close: bool = True, **kwargs: object, ) -> mock.Mock: assert "loop" not in kwargs try: loop = asyncio.get_running_loop() except RuntimeError: loop = asyncio.get_event_loop() f = loop.create_future() proto: mock.Mock = mock.create_autospec( ResponseHandler, instance=True, should_close=should_close, closed=f ) f.set_result(conn_closing_result) return proto async def test_connection_del(loop: asyncio.AbstractEventLoop) -> None: connector = mock.Mock() key = mock.Mock() protocol = mock.Mock() loop.set_debug(False) conn = Connection(connector, key, protocol, loop=loop) exc_handler = mock.Mock() loop.set_exception_handler(exc_handler) with pytest.warns(ResourceWarning): del conn gc.collect() await asyncio.sleep(0) connector._release.assert_called_with(key, protocol, should_close=True) msg = { "message": mock.ANY, "client_connection": mock.ANY, } exc_handler.assert_called_with(loop, msg) def test_connection_del_loop_debug(loop: asyncio.AbstractEventLoop) -> None: connector = mock.Mock() key = mock.Mock() protocol = mock.Mock() loop.set_debug(True) conn = Connection(connector, key, protocol, loop=loop) exc_handler = mock.Mock() loop.set_exception_handler(exc_handler) with pytest.warns(ResourceWarning): del conn gc.collect() msg = { "message": mock.ANY, "client_connection": mock.ANY, "source_traceback": mock.ANY, } exc_handler.assert_called_with(loop, msg) def test_connection_del_loop_closed(loop: asyncio.AbstractEventLoop) -> None: connector = mock.Mock() key = mock.Mock() protocol = mock.Mock() loop.set_debug(True) conn = Connection(connector, key, protocol, loop=loop) exc_handler = mock.Mock() loop.set_exception_handler(exc_handler) loop.close() with pytest.warns(ResourceWarning): del conn gc.collect() assert not 
connector._release.called assert not exc_handler.called async def test_del(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None: conn = aiohttp.BaseConnector() proto = create_mocked_conn(loop, should_close=False) conn._release(key, proto) conns_impl = conn._conns exc_handler = mock.Mock() loop.set_exception_handler(exc_handler) with pytest.warns(ResourceWarning): del conn gc.collect() assert not conns_impl proto.close.assert_called_with() msg = { "connector": mock.ANY, # conn was deleted "connections": mock.ANY, "message": "Unclosed connector", } exc_handler.assert_called_with(loop, msg) @pytest.mark.xfail async def test_del_with_scheduled_cleanup( # type: ignore[misc] loop: asyncio.AbstractEventLoop, key: ConnectionKey ) -> None: loop.set_debug(True) conn = aiohttp.BaseConnector(keepalive_timeout=0.01) transp = create_mocked_conn(loop) conn._conns[key] = deque([(transp, 123)]) conns_impl = conn._conns exc_handler = mock.Mock() loop.set_exception_handler(exc_handler) with pytest.warns(ResourceWarning): # NOTE(review): `del conn` presumably does not destroy the connector here, because the loop still holds a strong # reference to one of the connector's bound methods - confirm; this may be why the test is marked xfail.
async def test_close(key: ConnectionKey) -> None:
    """Closing a connector empties the pool and closes the pooled protocol."""
    pooled_proto = create_mocked_conn()

    connector = aiohttp.BaseConnector()
    assert not connector.closed

    pool_entry = (pooled_proto, 0)
    connector._conns[key] = deque([pool_entry])
    await connector.close()

    assert not connector._conns
    assert pooled_proto.close.called
    assert connector.closed
async def test_get(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None:
    """``_get`` yields nothing for an empty pool and a connection for a pooled protocol."""
    connector = aiohttp.BaseConnector()
    try:
        nothing_pooled = await connector._get(key, [])
        assert nothing_pooled is None

        pooled_proto = create_mocked_conn(loop)
        connector._conns[key] = deque([(pooled_proto, loop.time())])

        acquired = await connector._get(key, [])
        assert acquired is not None
        assert acquired.protocol == pooled_proto
        acquired.close()
    finally:
        await connector.close()
async def test_get_expired(loop: asyncio.AbstractEventLoop) -> None:
    """A pooled protocol stamped 1000 seconds in the past is not handed out."""
    connector = aiohttp.BaseConnector()
    conn_key = ConnectionKey("localhost", 80, False, False, None, None, None)
    try:
        assert await connector._get(conn_key, []) is None

        stale_proto = create_mocked_conn(loop)
        stale_since = loop.time() - 1000
        connector._conns[conn_key] = deque([(stale_proto, stale_since)])

        assert await connector._get(conn_key, []) is None
        # The stale entry must also have been removed from the pool.
        assert not connector._conns
    finally:
        await connector.close()
async def test_release_already_closed(key: ConnectionKey) -> None:
    """Releasing into an already closed connector touches neither bookkeeping helper."""
    connector = aiohttp.BaseConnector()
    proto = create_mocked_conn()
    connector._acquired.add(proto)
    await connector.close()

    with (
        mock.patch.object(
            connector, "_release_acquired", autospec=True, spec_set=True
        ) as release_acquired,
        mock.patch.object(
            connector, "_release_waiter", autospec=True, spec_set=True
        ) as release_waiter,
    ):
        connector._release(key, proto)
        assert not release_acquired.called
        assert not release_waiter.called
async def test_release_waiter_first_available(
    loop: asyncio.AbstractEventLoop, key: ConnectionKey, key2: ConnectionKey
) -> None:
    """With waiters queued under two keys, exactly one of them is woken."""
    connector = aiohttp.BaseConnector()

    waiter_a = mock.Mock()
    waiter_b = mock.Mock()
    for waiter in (waiter_a, waiter_b):
        waiter.done.return_value = False

    connector._waiters[key][waiter_b] = None
    connector._waiters[key2][waiter_a] = None
    connector._release_waiter()

    # Exactly one waiter got a result: the two ``called`` flags must differ.
    assert waiter_a.set_result.called != waiter_b.set_result.called
    await connector.close()
async def test__release_acquired_per_host3(
    loop: asyncio.AbstractEventLoop, key: ConnectionKey
) -> None:
    """Releasing one of two acquired handlers keeps the other under the same key."""
    connector = aiohttp.BaseConnector(limit_per_host=10)
    released = create_mocked_conn(loop)
    remaining = create_mocked_conn(loop)

    for proto in (released, remaining):
        connector._acquired_per_host[key].add(proto)

    connector._release_acquired(key, released)

    assert len(connector._acquired_per_host) == 1
    assert connector._acquired_per_host[key] == {remaining}
    await connector.close()
async def test_tcp_connector_server_hostname_override(
    loop: asyncio.AbstractEventLoop,
    start_connection: mock.AsyncMock,
    make_client_request: _RequestMaker,
) -> None:
    """A request-level ``server_hostname`` is forwarded to ``create_connection``."""
    connector = aiohttp.TCPConnector()
    patched_create = mock.patch.object(
        connector._loop, "create_connection", autospec=True, spec_set=True
    )
    with patched_create as create_connection:
        create_connection.return_value = (mock.Mock(), mock.Mock())

        request = make_client_request(
            "GET",
            URL("https://127.0.0.1:443"),
            loop=loop,
            server_hostname="localhost",
        )

        with closing(await connector.connect(request, [], ClientTimeout())):
            passed_kwargs = create_connection.call_args.kwargs
            assert passed_kwargs["server_hostname"] == "localhost"

    await connector.close()
URL("https://mocked.host"), ssl=aiohttp.Fingerprint(fingerprint), loop=loop, ) async def _resolve_host( host: str, port: int, traces: object = None ) -> list[ResolveResult]: return [ { "hostname": host, "host": ip, "port": port, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } for ip in ips ] os_error = certificate_error = ssl_error = fingerprint_error = False connected = False async def start_connection( addr_infos: Sequence[AddrInfoType], **kwargs: object ) -> socket.socket: first_addr_info = addr_infos[0] first_addr_info_addr = first_addr_info[-1] addrs_tried.append(first_addr_info_addr) mock_socket = mock.create_autospec(socket.socket, spec_set=True, instance=True) mock_socket.getpeername.return_value = first_addr_info_addr return mock_socket # type: ignore[no-any-return] async def create_connection( *args: object, sock: socket.socket | None = None, **kwargs: object ) -> tuple[ResponseHandler, ResponseHandler]: nonlocal os_error, certificate_error, ssl_error, fingerprint_error nonlocal connected assert isinstance(sock, socket.socket) addr_info = sock.getpeername() ip = addr_info[0] ips_tried.append(ip) if ip == ip1: os_error = True raise OSError if ip == ip2: certificate_error = True raise ssl.CertificateError if ip == ip3: ssl_error = True raise ssl.SSLError if ip == ip4: # Close the socket since we are not actually connecting # and we don't want to leak it. sock.close() fingerprint_error = True tr = create_mocked_conn(loop) pr = create_mocked_conn(loop) def get_extra_info(param: str) -> object: if param == "sslcontext": return True if param == "ssl_object": s = mock.Mock() s.getpeercert.return_value = b"not foo" return s if param == "peername": return ("192.168.1.5", 12345) if param == "socket": return sock assert False, param tr.get_extra_info = get_extra_info return tr, pr if ip == ip5: # Close the socket since we are not actually connecting # and we don't want to leak it. 
@pytest.mark.parametrize(
    ("happy_eyeballs_delay"),
    [0.1, 0.25, None],
)
async def test_tcp_connector_happy_eyeballs(  # type: ignore[misc]
    loop: asyncio.AbstractEventLoop,
    happy_eyeballs_delay: float | None,
    make_client_request: _RequestMaker,
) -> None:
    """The connector falls back to the second address when the first one fails.

    The resolver is stubbed to return one IPv6-looking and one IPv4 address;
    ``sock_connect`` raises ``OSError`` for the first of them.
    """
    conn = aiohttp.TCPConnector(happy_eyeballs_delay=happy_eyeballs_delay)

    failing_ip = "dead::beef::"
    working_ip = "192.168.1.1"
    resolved_ips = [failing_ip, working_ip]
    addrs_tried = []

    req = make_client_request("GET", URL("https://mocked.host"), loop=loop)

    async def _resolve_host(
        host: str, port: int, traces: object = None
    ) -> list[ResolveResult]:
        results = []
        for ip in resolved_ips:
            family = socket.AF_INET6 if ":" in ip else socket.AF_INET
            results.append(
                {
                    "hostname": host,
                    "host": ip,
                    "port": port,
                    "family": family,
                    "proto": 0,
                    "flags": socket.AI_NUMERICHOST,
                }
            )
        return results

    os_error = False
    connected = False

    async def sock_connect(*args: tuple[str, int], **kwargs: object) -> None:
        nonlocal os_error
        addr = args[1]
        addrs_tried.append(addr)
        if addr[0] == failing_ip:
            os_error = True
            raise OSError

    async def create_connection(
        *args: object, sock: socket.socket | None = None, **kwargs: object
    ) -> tuple[ResponseHandler, ResponseHandler]:
        nonlocal connected
        assert isinstance(sock, socket.socket)
        # Nothing is really connected, so close the socket here to avoid
        # leaking it.
        sock.close()
        connected = True
        return create_mocked_conn(loop), create_mocked_conn(loop)

    with (
        mock.patch.object(
            conn,
            "_resolve_host",
            autospec=True,
            spec_set=True,
            side_effect=_resolve_host,
        ),
        mock.patch.object(
            conn._loop,
            "sock_connect",
            autospec=True,
            spec_set=True,
            side_effect=sock_connect,
        ),
        mock.patch.object(
            conn._loop,
            "create_connection",
            autospec=True,
            spec_set=True,
            side_effect=create_connection,
        ),
    ):
        established_connection = await conn.connect(req, [], ClientTimeout())

    assert addrs_tried == [(failing_ip, 443, 0, 0), (working_ip, 443)]
    assert os_error
    assert connected

    established_connection.close()
    await conn.close()
async def test_tcp_connector_family_is_respected(
    loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker
) -> None:
    """A connector created with ``family=AF_INET`` only tries IPv4 addresses.

    The resolver is stubbed to return one IPv6-looking and one IPv4 address;
    every address passed to ``sock_connect`` is recorded and only the IPv4 one
    is expected to show up.
    """
    conn = aiohttp.TCPConnector(family=socket.AF_INET)

    ip1 = "dead::beef::"
    ip2 = "192.168.1.1"
    ips = [ip1, ip2]
    addrs_tried = []

    req = make_client_request(
        "GET",
        URL("https://mocked.host"),
        loop=loop,
    )

    async def _resolve_host(
        host: str, port: int, traces: object = None
    ) -> list[ResolveResult]:
        return [
            {
                "hostname": host,
                "host": ip,
                "port": port,
                "family": socket.AF_INET6 if ":" in ip else socket.AF_INET,
                "proto": 0,
                "flags": socket.AI_NUMERICHOST,
            }
            for ip in ips
        ]

    connected = False

    async def sock_connect(*args: tuple[str, int], **kwargs: object) -> None:
        addr = args[1]
        addrs_tried.append(addr)

    async def create_connection(
        *args: object, sock: socket.socket | None = None, **kwargs: object
    ) -> tuple[ResponseHandler, ResponseHandler]:
        nonlocal connected
        assert isinstance(sock, socket.socket)
        # Close the socket since we are not actually connecting
        # and we don't want to leak it.
        sock.close()
        connected = True
        tr = create_mocked_conn(loop)
        pr = create_mocked_conn(loop)
        return tr, pr

    with mock.patch.object(
        conn, "_resolve_host", autospec=True, spec_set=True, side_effect=_resolve_host
    ):
        with mock.patch.object(
            conn._loop,
            "sock_connect",
            autospec=True,
            spec_set=True,
            side_effect=sock_connect,
        ):
            with mock.patch.object(
                conn._loop,
                "create_connection",
                autospec=True,
                spec_set=True,
                side_effect=create_connection,
            ):
                established_connection = await conn.connect(req, [], ClientTimeout())

    # We should only try the IPv4 address since we specified
    # the family to be AF_INET
    assert addrs_tried == [(ip2, 443)]
    assert connected

    established_connection.close()
    # Close the connector like the sibling TCP connector tests do, so it is
    # not left unclosed when the test returns.
    await conn.close()
async def test_tcp_connector_resolve_host(loop: asyncio.AbstractEventLoop) -> None:
    """Resolving ``localhost`` yields loopback records carrying the requested port."""
    connector = aiohttp.TCPConnector(use_dns_cache=True)

    records = await connector._resolve_host("localhost", 8080)
    assert records

    for record in records:
        family = record["family"]
        # Only IPv4 and IPv6 records are acceptable.
        assert family in (socket.AF_INET, socket.AF_INET6)
        assert record["hostname"] == "localhost"
        assert record["port"] == 8080

        if family == socket.AF_INET:
            assert record["host"] == "127.0.0.1"
        elif platform.system() == "Darwin":
            assert record["host"] in ("::1", "fe80::1", "fe80::1%lo0")
        else:
            assert record["host"] == "::1"

    await connector.close()
async def test_tcp_connector_dns_cache_forever(
    loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]]
) -> None:
    """With ``ttl_dns_cache=None`` a second lookup is answered from the cache.

    The default resolver is replaced by an autospecced mock, so the number of
    ``resolve`` calls shows whether the second lookup reached the resolver.
    """
    with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
        mock_default_resolver = mock.create_autospec(
            AsyncResolver, instance=True, spec_set=True
        )
        mock_default_resolver.resolve.return_value = await dns_response()
        m_resolver.return_value = mock_default_resolver

        # ``None`` means entries never expire. The previous value of 10 made
        # this test a duplicate of test_tcp_connector_dns_cache_not_expired.
        conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=None)
        await conn._resolve_host("localhost", 8080)
        await conn._resolve_host("localhost", 8080)
        mock_default_resolver.resolve.assert_called_once_with(
            "localhost", 8080, family=0
        )

    await conn.close()
async def test_tcp_connector_dns_throttle_requests_exception_spread(
    loop: asyncio.AbstractEventLoop,
) -> None:
    """A resolver failure is delivered to both concurrent lookups of one host."""
    with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
        failure = Exception()
        failing_resolver = mock.create_autospec(
            AbstractResolver, instance=True, spec_set=True
        )
        failing_resolver.resolve.side_effect = failure
        m_resolver.return_value = failing_resolver

        conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10)
        lookups = [
            loop.create_task(conn._resolve_host("localhost", 8080)) for _ in range(2)
        ]

        # Yield to the loop four times so both tasks can run to completion.
        for _ in range(4):
            await asyncio.sleep(0)

        first, second = lookups
        assert first.exception() == failure
        assert second.exception() == failure

    await conn.close()
async def test_tcp_connector_cancel_dns_error_captured(
    loop: asyncio.AbstractEventLoop,
    dns_response_error: Callable[[], Awaitable[NoReturn]],
    make_client_request: _RequestMaker,
) -> None:
    """Cancelling a connect whose DNS lookup fails must not hit the loop's exception handler."""
    # The mock records by itself whether the loop ever invoked the handler.
    handler = mock.Mock()
    loop.set_exception_handler(handler)

    with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
        req = make_client_request("GET", URL("http://temporary-failure:80"), loop=loop)
        conn = aiohttp.TCPConnector(use_dns_cache=False)

        resolver = m_resolver()
        resolver.resolve.return_value = dns_response_error()
        resolver.close = mock.AsyncMock()

        connect_task = loop.create_task(
            conn._create_direct_connection(req, [], ClientTimeout())
        )

        await asyncio.sleep(0)
        connect_task.cancel()
        with pytest.raises(asyncio.CancelledError):
            await connect_task

        gc.collect()
        assert handler.called is False
        await conn.close()
async def test_tcp_connector_dns_tracing_cache_disabled(
    loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]]
) -> None:
    """With the DNS cache off, both lookups fire the resolve start and end traces."""
    session = mock.Mock()
    trace_config_ctx = mock.Mock()
    start_cb = mock.AsyncMock()
    end_cb = mock.AsyncMock()

    trace_config = aiohttp.TraceConfig(
        trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx)
    )
    trace_config.on_dns_resolvehost_start.append(start_cb)
    trace_config.on_dns_resolvehost_end.append(end_cb)
    trace_config.freeze()
    traces = [Trace(session, trace_config, trace_config.trace_config_ctx())]

    with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver:
        conn = aiohttp.TCPConnector(use_dns_cache=False)

        resolver = m_resolver()
        resolver.resolve.side_effect = [dns_response(), dns_response()]
        resolver.close = mock.AsyncMock()

        for _ in range(2):
            await conn._resolve_host("localhost", 8080, traces=traces)

        # One expected call per lookup, identical for both lookups.
        start_call = mock.call(
            session,
            trace_config_ctx,
            aiohttp.TraceDnsResolveHostStartParams("localhost"),
        )
        end_call = mock.call(
            session,
            trace_config_ctx,
            aiohttp.TraceDnsResolveHostEndParams("localhost"),
        )
        start_cb.assert_has_calls([start_call, start_call])
        end_cb.assert_has_calls([end_call, end_call])

        await conn.close()
trace_config_ctx, aiohttp.TraceDnsCacheHitParams("localhost") ) on_dns_cache_miss.assert_called_once_with( session, trace_config_ctx, aiohttp.TraceDnsCacheMissParams("localhost") ) t.cancel() t1.cancel() with pytest.raises(asyncio.CancelledError): await asyncio.gather(t, t1) await conn.close() async def test_tcp_connector_close_resolver() -> None: m_resolver = mock.create_autospec(AbstractResolver, instance=True, spec_set=True) with mock.patch("aiohttp.connector.DefaultResolver", return_value=m_resolver): conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) await conn.close() m_resolver.close.assert_awaited_once() async def test_dns_error( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: connector = aiohttp.TCPConnector() with mock.patch.object( connector, "_resolve_host", autospec=True, spec_set=True, side_effect=OSError("dont take it serious"), ): req = make_client_request("GET", URL("http://www.python.org"), loop=loop) with pytest.raises(aiohttp.ClientConnectorError): await connector.connect(req, [], ClientTimeout()) await connector.close() async def test_get_pop_empty_conns( loop: asyncio.AbstractEventLoop, key: ConnectionKey ) -> None: # see issue #473 conn = aiohttp.BaseConnector() assert await conn._get(key, []) is None assert not conn._conns await conn.close() async def test_release_close_do_not_add_to_pool( loop: asyncio.AbstractEventLoop, key: ConnectionKey ) -> None: # see issue #473 conn = aiohttp.BaseConnector() proto = create_mocked_conn(loop, should_close=True) conn._acquired.add(proto) conn._release(key, proto) assert not conn._conns await conn.close() async def test_release_close_do_not_delete_existing_connections( loop: asyncio.AbstractEventLoop, key: ConnectionKey ) -> None: proto1 = create_mocked_conn(loop) conn = aiohttp.BaseConnector() conn._conns[key] = deque([(proto1, 1)]) proto = create_mocked_conn(loop, should_close=True) conn._acquired.add(proto) conn._release(key, proto) assert conn._conns[key] 
== deque([(proto1, 1)]) assert proto.close.called await conn.close() async def test_release_not_started( loop: asyncio.AbstractEventLoop, key: ConnectionKey ) -> None: conn = aiohttp.BaseConnector() proto = create_mocked_conn(should_close=False) conn._acquired.add(proto) conn._release(key, proto) # assert conn._conns == {key: [(proto, 10)]} rec = conn._conns[key] assert rec[0][0] == proto assert rec[0][1] == pytest.approx(loop.time(), abs=0.05) assert not proto.close.called await conn.close() async def test_release_not_opened( loop: asyncio.AbstractEventLoop, key: ConnectionKey ) -> None: conn = aiohttp.BaseConnector() proto = create_mocked_conn(loop) conn._acquired.add(proto) conn._release(key, proto) assert proto.close.called await conn.close() async def test_connect( loop: asyncio.AbstractEventLoop, key: ConnectionKey, make_client_request: _RequestMaker, ) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True req = make_client_request("GET", URL("http://localhost:80"), loop=loop) conn = aiohttp.BaseConnector() conn._conns[key] = deque([(proto, loop.time())]) with mock.patch.object(conn, "_create_connection", create_mocked_conn(loop)) as m: m.return_value = loop.create_future() m.return_value.set_result(proto) connection = await conn.connect(req, [], ClientTimeout()) assert not m.called assert connection._protocol is proto assert connection.transport is proto.transport assert isinstance(connection, Connection) connection.close() await conn.close() async def test_connect_tracing( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() on_connection_create_start = mock.AsyncMock() on_connection_create_end = mock.AsyncMock() trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) ) trace_config.on_connection_create_start.append(on_connection_create_start) 
trace_config.on_connection_create_end.append(on_connection_create_end) trace_config.freeze() traces = [Trace(session, trace_config, trace_config.trace_config_ctx())] proto = create_mocked_conn(loop) proto.is_connected.return_value = True req = make_client_request("GET", URL("http://host:80"), loop=loop) conn = aiohttp.BaseConnector() with mock.patch.object( conn, "_create_connection", autospec=True, spec_set=True, return_value=proto ): conn2 = await conn.connect(req, traces, ClientTimeout()) conn2.release() on_connection_create_start.assert_called_with( session, trace_config_ctx, aiohttp.TraceConnectionCreateStartParams() ) on_connection_create_end.assert_called_with( session, trace_config_ctx, aiohttp.TraceConnectionCreateEndParams() ) @pytest.mark.parametrize( "signal", [ "on_connection_create_start", "on_connection_create_end", ], ) async def test_exception_during_connetion_create_tracing( # type: ignore[misc] loop: asyncio.AbstractEventLoop, signal: str, make_client_request: _RequestMaker ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() on_signal = mock.AsyncMock(side_effect=asyncio.CancelledError) trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) ) getattr(trace_config, signal).append(on_signal) trace_config.freeze() traces = [Trace(session, trace_config, trace_config.trace_config_ctx())] proto = create_mocked_conn(loop) proto.is_connected.return_value = True req = make_client_request("GET", URL("http://host:80"), loop=loop) key = req.connection_key conn = aiohttp.BaseConnector() assert not conn._acquired assert key not in conn._acquired_per_host with ( pytest.raises(asyncio.CancelledError), mock.patch.object( conn, "_create_connection", autospec=True, spec_set=True, return_value=proto ), ): await conn.connect(req, traces, ClientTimeout()) assert not conn._acquired assert key not in conn._acquired_per_host async def test_exception_during_connection_queued_tracing( loop: 
asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() on_signal = mock.AsyncMock(side_effect=asyncio.CancelledError) trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) ) trace_config.on_connection_queued_start.append(on_signal) trace_config.freeze() traces = [Trace(session, trace_config, trace_config.trace_config_ctx())] proto = create_mocked_conn(loop) proto.is_connected.return_value = True req = make_client_request("GET", URL("http://host:80"), loop=loop) key = req.connection_key conn = aiohttp.BaseConnector(limit=1) assert not conn._acquired assert key not in conn._acquired_per_host with ( pytest.raises(asyncio.CancelledError), mock.patch.object( conn, "_create_connection", autospec=True, spec_set=True, return_value=proto ), ): resp1 = await conn.connect(req, traces, ClientTimeout()) assert resp1 # 2nd connect request will be queued await conn.connect(req, traces, ClientTimeout()) resp1.close() assert not conn._waiters assert not conn._acquired assert key not in conn._acquired_per_host await conn.close() async def test_exception_during_connection_reuse_tracing( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() on_signal = mock.AsyncMock(side_effect=asyncio.CancelledError) trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) ) trace_config.on_connection_reuseconn.append(on_signal) trace_config.freeze() traces = [Trace(session, trace_config, trace_config.trace_config_ctx())] proto = create_mocked_conn(loop) proto.is_connected.return_value = True req = make_client_request("GET", URL("http://host:80"), loop=loop) key = req.connection_key conn = aiohttp.BaseConnector() assert not conn._acquired assert key not in conn._acquired_per_host with ( pytest.raises(asyncio.CancelledError), mock.patch.object( conn, 
"_create_connection", autospec=True, spec_set=True, return_value=proto ), ): resp = await conn.connect(req, traces, ClientTimeout()) with mock.patch.object(resp.protocol, "should_close", False): resp.release() assert not conn._acquired assert key not in conn._acquired_per_host assert key in conn._conns await conn.connect(req, traces, ClientTimeout()) assert not conn._acquired assert key not in conn._acquired_per_host async def test_cancellation_during_waiting_for_free_connection( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker, ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() waiter_wait_stated_future = loop.create_future() async def on_connection_queued_start(*args: object, **kwargs: object) -> None: waiter_wait_stated_future.set_result(None) trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) ) trace_config.on_connection_queued_start.append(on_connection_queued_start) trace_config.freeze() traces = [Trace(session, trace_config, trace_config.trace_config_ctx())] proto = create_mocked_conn(loop) proto.is_connected.return_value = True req = make_client_request("GET", URL("http://host:80"), loop=loop) key = req.connection_key conn = aiohttp.BaseConnector(limit=1) assert not conn._acquired assert key not in conn._acquired_per_host with mock.patch.object( conn, "_create_connection", autospec=True, spec_set=True, return_value=proto ): resp1 = await conn.connect(req, traces, ClientTimeout()) assert resp1 # 2nd connect request will be queued task = asyncio.create_task(conn.connect(req, traces, ClientTimeout())) await waiter_wait_stated_future list(conn._waiters[key])[0].cancel() with pytest.raises(asyncio.CancelledError): await task resp1.close() assert not conn._waiters assert not conn._acquired assert key not in conn._acquired_per_host async def test_close_during_connect( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: proto = create_mocked_conn(loop) 
proto.is_connected.return_value = True fut = loop.create_future() req = make_client_request("GET", URL("http://host:80"), loop=loop) conn = aiohttp.BaseConnector() with mock.patch.object(conn, "_create_connection", lambda *args: fut): task = loop.create_task(conn.connect(req, [], ClientTimeout())) await asyncio.sleep(0) await conn.close() fut.set_result(proto) with pytest.raises(aiohttp.ClientConnectionError): await task assert proto.close.called @pytest.mark.usefixtures("enable_cleanup_closed") async def test_ctor_cleanup() -> None: loop = mock.Mock() loop.time.return_value = 1.5 conn = aiohttp.BaseConnector(keepalive_timeout=10, enable_cleanup_closed=True) assert conn._cleanup_handle is None assert conn._cleanup_closed_handle is not None await conn.close() async def test_cleanup(key: ConnectionKey) -> None: # The test sets the clock to 300s. It starts with 2 connections in the # pool. The first connection has use time of 10s. When cleanup reaches it, # it computes the deadline = 300 - 15.0 = 285.0 (15s being the default # keep-alive timeout value), then checks that it's overdue because # 10 - 285.0 < 0, and releases it since it's in connected state. The second # connection, though, is in disconnected state so it doesn't bother to # check if it's past due and closes the underlying transport. 
m1 = mock.Mock() m2 = mock.Mock() m1.is_connected.return_value = True m2.is_connected.return_value = False testset: defaultdict[ConnectionKey, deque[tuple[ResponseHandler, float]]] = ( defaultdict(deque) ) testset[key] = deque([(m1, 10), (m2, 300)]) loop = mock.Mock() loop.time.return_value = 300 async with aiohttp.BaseConnector() as conn: conn._loop = loop conn._conns = testset existing_handle = conn._cleanup_handle = mock.Mock() with mock.patch("aiohttp.connector.monotonic", return_value=300): conn._cleanup() assert existing_handle.cancel.called assert conn._conns == {} assert conn._cleanup_handle is None @pytest.mark.usefixtures("enable_cleanup_closed") async def test_cleanup_close_ssl_transport( # type: ignore[misc] loop: asyncio.AbstractEventLoop, ssl_key: ConnectionKey ) -> None: proto = create_mocked_conn(loop) transport = proto.transport testset: defaultdict[ConnectionKey, deque[tuple[ResponseHandler, float]]] = ( defaultdict(deque) ) testset[ssl_key] = deque([(proto, 10)]) loop = mock.Mock() new_time = asyncio.get_event_loop().time() + 300 loop.time.return_value = new_time conn = aiohttp.BaseConnector(enable_cleanup_closed=True) conn._loop = loop conn._conns = testset existing_handle = conn._cleanup_handle = mock.Mock() with mock.patch("aiohttp.connector.monotonic", return_value=new_time): conn._cleanup() assert existing_handle.cancel.called assert conn._conns == {} assert conn._cleanup_closed_transports == [transport] await conn.close() await asyncio.sleep(0) # Give cleanup a chance to close transports async def test_cleanup2(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None: m = create_mocked_conn() m.is_connected.return_value = True testset: defaultdict[ConnectionKey, deque[tuple[ResponseHandler, float]]] = ( defaultdict(deque) ) testset[key] = deque([(m, 300)]) conn = aiohttp.BaseConnector(keepalive_timeout=10) conn._loop = mock.Mock() conn._loop.time.return_value = 300 with mock.patch("aiohttp.connector.monotonic", return_value=300): 
conn._conns = testset conn._cleanup() assert conn._conns == testset assert conn._cleanup_handle is not None conn._loop.call_at.assert_called_with(310, mock.ANY, mock.ANY) await conn.close() async def test_cleanup3(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None: m = create_mocked_conn(loop) m.is_connected.return_value = True testset: defaultdict[ConnectionKey, deque[tuple[ResponseHandler, float]]] = ( defaultdict(deque) ) testset[key] = deque([(m, 290.1), (create_mocked_conn(loop), 305.1)]) conn = aiohttp.BaseConnector(keepalive_timeout=10) conn._loop = mock.Mock() conn._loop.time.return_value = 308.5 conn._conns = testset with mock.patch("aiohttp.connector.monotonic", return_value=308.5): conn._cleanup() assert conn._conns == {key: deque([testset[key][1]])} assert conn._cleanup_handle is not None conn._loop.call_at.assert_called_with(319, mock.ANY, mock.ANY) await conn.close() @pytest.mark.usefixtures("enable_cleanup_closed") async def test_cleanup_closed( loop: asyncio.AbstractEventLoop, mocker: MockerFixture ) -> None: m = mocker.spy(loop, "call_at") conn = aiohttp.BaseConnector(enable_cleanup_closed=True) tr = mock.Mock() conn._cleanup_closed_handle = cleanup_closed_handle = mock.Mock() conn._cleanup_closed_transports = [tr] conn._cleanup_closed() assert tr.abort.called assert not conn._cleanup_closed_transports assert m.called assert cleanup_closed_handle.cancel.called await conn.close() async def test_cleanup_closed_is_noop_on_fixed_cpython() -> None: """Ensure that enable_cleanup_closed is a noop on fixed Python versions.""" with ( mock.patch("aiohttp.connector.NEEDS_CLEANUP_CLOSED", False), pytest.warns(DeprecationWarning, match="cleanup_closed ignored"), ): conn = aiohttp.BaseConnector(enable_cleanup_closed=True) assert conn._cleanup_closed_disabled is True async def test_cleanup_closed_disabled( loop: asyncio.AbstractEventLoop, mocker: MockerFixture ) -> None: conn = aiohttp.BaseConnector(enable_cleanup_closed=False) tr = mock.Mock() 
conn._cleanup_closed_transports = [tr] conn._cleanup_closed() assert tr.abort.called assert not conn._cleanup_closed_transports await conn.close() async def test_tcp_connector_ctor(loop: asyncio.AbstractEventLoop) -> None: conn = aiohttp.TCPConnector() assert conn._ssl is True assert conn.use_dns_cache assert conn.family == 0 await conn.close() @pytest.mark.skipif( sys.version_info < (3, 11), reason="Use test_tcp_connector_ssl_shutdown_timeout_pre_311 for Python < 3.11", ) async def test_tcp_connector_ssl_shutdown_timeout( loop: asyncio.AbstractEventLoop, ) -> None: # Test default value (no warning expected) conn = aiohttp.TCPConnector() assert conn._ssl_shutdown_timeout == 0 await conn.close() # Test custom value - expect deprecation warning with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=1.0) assert conn._ssl_shutdown_timeout == 1.0 await conn.close() # Test None value - expect deprecation warning with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=None) assert conn._ssl_shutdown_timeout is None await conn.close() @pytest.mark.skipif( sys.version_info >= (3, 11), reason="This test is for Python < 3.11 runtime warning behavior", ) async def test_tcp_connector_ssl_shutdown_timeout_pre_311( loop: asyncio.AbstractEventLoop, ) -> None: """Test that both deprecation and runtime warnings are issued on Python < 3.11.""" # Test custom value - expect both deprecation and runtime warnings with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") conn = aiohttp.TCPConnector(ssl_shutdown_timeout=1.0) # Should have both deprecation and runtime warnings assert len(w) == 2 assert any(issubclass(warn.category, DeprecationWarning) for warn in w) assert any(issubclass(warn.category, RuntimeWarning) for warn in w) assert conn._ssl_shutdown_timeout == 1.0 await conn.close() 
@pytest.mark.skipif( sys.version_info < (3, 11), reason="ssl_shutdown_timeout requires Python 3.11+" ) async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( # type: ignore[misc] loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock, make_client_request: _RequestMaker, ) -> None: # Test that ssl_shutdown_timeout is passed to create_connection for SSL connections with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=2.5) with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() req = make_client_request("GET", URL("https://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): assert create_connection.call_args.kwargs["ssl_shutdown_timeout"] == 2.5 await conn.close() # Test with None value with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=None) with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() req = make_client_request("GET", URL("https://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): # When ssl_shutdown_timeout is None, it should not be in kwargs assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs await conn.close() # Test that ssl_shutdown_timeout is NOT passed for non-SSL connections with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=2.5) with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() req = make_client_request("GET", 
URL("http://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): # For non-SSL connections, ssl_shutdown_timeout should not be passed assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs await conn.close() @pytest.mark.skipif(sys.version_info >= (3, 11), reason="Test for Python < 3.11") async def test_tcp_connector_ssl_shutdown_timeout_not_passed_pre_311( # type: ignore[misc] loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock, make_client_request: _RequestMaker, ) -> None: # Test that ssl_shutdown_timeout is NOT passed to create_connection on Python < 3.11 with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") conn = aiohttp.TCPConnector(ssl_shutdown_timeout=2.5) # Should have both deprecation and runtime warnings assert len(w) == 2 assert any(issubclass(warn.category, DeprecationWarning) for warn in w) assert any(issubclass(warn.category, RuntimeWarning) for warn in w) with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() # Test with HTTPS req = make_client_request("GET", URL("https://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs # Test with HTTP req = make_client_request("GET", URL("http://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs await conn.close() async def test_tcp_connector_close_abort_ssl_when_shutdown_timeout_zero( loop: asyncio.AbstractEventLoop, ) -> None: """Test that close() uses abort() for SSL connections when ssl_shutdown_timeout=0.""" with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=0) # Create a mock SSL protocol proto = 
mock.create_autospec(ResponseHandler, instance=True) proto.closed = None # Create mock SSL transport transport = mock.Mock() transport.get_extra_info.return_value = mock.Mock() # Returns SSL context transport.is_closing.return_value = False proto.transport = transport # Add the protocol to acquired connections conn._acquired.add(proto) # Close the connector await conn.close() # Verify abort was called instead of close for SSL connection proto.abort.assert_called_once() proto.close.assert_not_called() async def test_tcp_connector_close_doesnt_abort_non_ssl_when_shutdown_timeout_zero( loop: asyncio.AbstractEventLoop, ) -> None: """Test that close() still uses close() for non-SSL connections even when ssl_shutdown_timeout=0.""" with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=0) # Create a mock non-SSL protocol proto = mock.create_autospec(ResponseHandler, instance=True) proto.closed = None # Create mock non-SSL transport transport = mock.Mock() transport.get_extra_info.return_value = None # No SSL context transport.is_closing.return_value = False proto.transport = transport # Add the protocol to acquired connections conn._acquired.add(proto) # Close the connector await conn.close() # Verify close was called for non-SSL connection proto.close.assert_called_once() proto.abort.assert_not_called() async def test_tcp_connector_ssl_shutdown_timeout_warning_pre_311( loop: asyncio.AbstractEventLoop, ) -> None: """Test that a warning is issued for non-zero ssl_shutdown_timeout on Python < 3.11.""" with ( mock.patch.object(sys, "version_info", (3, 10, 0)), warnings.catch_warnings(record=True) as w, ): warnings.simplefilter("always") conn = aiohttp.TCPConnector(ssl_shutdown_timeout=5.0) # We should get two warnings: deprecation and runtime warning assert len(w) == 2 # Find each warning type deprecation_warning = next( (warn for warn in w if issubclass(warn.category, 
DeprecationWarning)), None ) runtime_warning = next( (warn for warn in w if issubclass(warn.category, RuntimeWarning)), None ) assert deprecation_warning is not None assert "ssl_shutdown_timeout parameter is deprecated" in str( deprecation_warning.message ) assert runtime_warning is not None assert "ssl_shutdown_timeout=5.0 is ignored on Python < 3.11" in str( runtime_warning.message ) assert "only ssl_shutdown_timeout=0 is supported" in str( runtime_warning.message ) # Verify the value is still stored assert conn._ssl_shutdown_timeout == 5.0 await conn.close() async def test_tcp_connector_ssl_shutdown_timeout_zero_no_warning_pre_311( loop: asyncio.AbstractEventLoop, ) -> None: """Test that no warning is issued for ssl_shutdown_timeout=0 on Python < 3.11.""" with ( mock.patch.object(sys, "version_info", (3, 10, 0)), warnings.catch_warnings(record=True) as w, ): warnings.simplefilter("always") conn = aiohttp.TCPConnector(ssl_shutdown_timeout=0) # We should get one warning: deprecation assert len(w) == 1 assert issubclass(w[0].category, DeprecationWarning) assert "ssl_shutdown_timeout parameter is deprecated" in str(w[0].message) assert conn._ssl_shutdown_timeout == 0 await conn.close() async def test_tcp_connector_ssl_shutdown_timeout_sentinel_no_warning_pre_311( loop: asyncio.AbstractEventLoop, ) -> None: """Test that no warning is issued when sentinel is used on Python < 3.11.""" with ( mock.patch.object(sys, "version_info", (3, 10, 0)), warnings.catch_warnings(record=True) as w, ): warnings.simplefilter("always") conn = aiohttp.TCPConnector() # Uses sentinel by default assert len(w) == 0 assert conn._ssl_shutdown_timeout == 0 # Default value await conn.close() async def test_tcp_connector_ssl_shutdown_timeout_zero_not_passed( loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock, make_client_request: _RequestMaker, ) -> None: """Test that ssl_shutdown_timeout=0 is NOT passed to create_connection.""" with pytest.warns( DeprecationWarning, 
match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=0) with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() # Test with HTTPS req = make_client_request("GET", URL("https://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): # Verify ssl_shutdown_timeout was NOT passed assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs # Test with HTTP (should not have ssl_shutdown_timeout anyway) req = make_client_request("GET", URL("http://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs await conn.close() @pytest.mark.skipif( sys.version_info < (3, 11), reason="ssl_shutdown_timeout requires Python 3.11+" ) async def test_tcp_connector_ssl_shutdown_timeout_nonzero_passed( # type: ignore[misc] loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock, make_client_request: _RequestMaker, ) -> None: """Test that non-zero ssl_shutdown_timeout IS passed to create_connection on Python 3.11+.""" with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=5.0) with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() # Test with HTTPS req = make_client_request("GET", URL("https://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): # Verify ssl_shutdown_timeout WAS passed assert create_connection.call_args.kwargs["ssl_shutdown_timeout"] == 5.0 # Test with HTTP (should not have ssl_shutdown_timeout) req = make_client_request("GET", URL("http://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): 
assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs await conn.close() async def test_tcp_connector_close_abort_ssl_connections_in_conns( loop: asyncio.AbstractEventLoop, ) -> None: """Test that SSL connections in _conns are aborted when ssl_shutdown_timeout=0.""" with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=0) # Create mock SSL protocol proto = mock.create_autospec(ResponseHandler, instance=True) proto.closed = None # Create mock SSL transport transport = mock.Mock() transport.get_extra_info.return_value = mock.Mock() # Returns SSL context proto.transport = transport # Add the protocol to _conns key = ConnectionKey("host", 443, True, True, None, None, None) conn._conns[key] = deque([(proto, loop.time())]) # Close the connector await conn.close() # Verify abort was called for SSL connection proto.abort.assert_called_once() proto.close.assert_not_called() async def test_tcp_connector_allowed_protocols(loop: asyncio.AbstractEventLoop) -> None: conn = aiohttp.TCPConnector() assert conn.allowed_protocol_schema_set == {"", "tcp", "http", "https", "ws", "wss"} async def test_start_tls_exception_with_ssl_shutdown_timeout_zero( loop: asyncio.AbstractEventLoop, ) -> None: """Test _start_tls_connection exception handling with ssl_shutdown_timeout=0.""" with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=0) underlying_transport = mock.Mock() req = mock.Mock() req.server_hostname = None req.host = "example.com" req.is_ssl = mock.Mock(return_value=True) # Patch _get_ssl_context to return a valid context and make start_tls fail with ( mock.patch.object( conn, "_get_ssl_context", return_value=ssl.create_default_context() ), mock.patch.object(conn._loop, "start_tls", side_effect=OSError("TLS failed")), ): with pytest.raises(OSError): await 
conn._start_tls_connection(underlying_transport, req, ClientTimeout()) # Should abort, not close underlying_transport.abort.assert_called_once() underlying_transport.close.assert_not_called() @pytest.mark.skipif( sys.version_info < (3, 11), reason="Use test_start_tls_exception_with_ssl_shutdown_timeout_nonzero_pre_311 for Python < 3.11", ) async def test_start_tls_exception_with_ssl_shutdown_timeout_nonzero( loop: asyncio.AbstractEventLoop, ) -> None: """Test _start_tls_connection exception handling with ssl_shutdown_timeout>0.""" with pytest.warns( DeprecationWarning, match="ssl_shutdown_timeout parameter is deprecated" ): conn = aiohttp.TCPConnector(ssl_shutdown_timeout=1.0) underlying_transport = mock.Mock() req = mock.Mock() req.server_hostname = None req.host = "example.com" req.is_ssl = mock.Mock(return_value=True) # Patch _get_ssl_context to return a valid context and make start_tls fail with ( mock.patch.object( conn, "_get_ssl_context", return_value=ssl.create_default_context() ), mock.patch.object(conn._loop, "start_tls", side_effect=OSError("TLS failed")), ): with pytest.raises(OSError): await conn._start_tls_connection(underlying_transport, req, ClientTimeout()) # Should close, not abort underlying_transport.close.assert_called_once() underlying_transport.abort.assert_not_called() @pytest.mark.skipif( sys.version_info >= (3, 11), reason="This test is for Python < 3.11 runtime warning behavior", ) async def test_start_tls_exception_with_ssl_shutdown_timeout_nonzero_pre_311( loop: asyncio.AbstractEventLoop, ) -> None: """Test _start_tls_connection exception handling with ssl_shutdown_timeout>0 on Python < 3.11.""" with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") conn = aiohttp.TCPConnector(ssl_shutdown_timeout=1.0) # Should have both deprecation and runtime warnings assert len(w) == 2 assert any(issubclass(warn.category, DeprecationWarning) for warn in w) assert any(issubclass(warn.category, RuntimeWarning) for warn in w) 
async def test_tcp_connector_ctor_fingerprint_valid(
    loop: asyncio.AbstractEventLoop,
) -> None:
    """A SHA-256 ``Fingerprint`` given as ``ssl`` is kept as-is on the connector."""
    digest = hashlib.sha256(b"foo").digest()
    fingerprint = aiohttp.Fingerprint(digest)

    connector = aiohttp.TCPConnector(ssl=fingerprint)

    # The very same object must be stored, not a copy.
    assert connector._ssl is fingerprint
    await connector.close()
async def test___get_ssl_context1() -> None:
    """A request that is not SSL yields no SSL context."""
    plain_request = mock.Mock()
    plain_request.is_ssl.return_value = False

    connector = aiohttp.TCPConnector()
    ssl_context = connector._get_ssl_context(plain_request)
    assert ssl_context is None
    await connector.close()
async def test_ssl_context_once() -> None:
    """Test the ssl context is created only once and shared between connectors.

    Three independent connectors must all hand back the same module-level
    verified context object for a request with ``ssl=True``.
    """
    connectors = [aiohttp.TCPConnector() for _ in range(3)]
    req = mock.Mock()
    req.is_ssl.return_value = True
    req.ssl = True
    try:
        for conn in connectors:
            assert conn._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED
    finally:
        # Close every connector, as the neighbouring _get_ssl_context tests do,
        # so none is left open when an assertion fails.
        for conn in connectors:
            await conn.close()
async def test_multiple_dns_resolution_requests_success(
    loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker
) -> None:
    """Verify that multiple DNS resolution requests are handled correctly.

    Three concurrent ``connect()`` calls are made for the same host while the
    resolver is slow. Resolution succeeds, but the connection attempt itself is
    patched to fail, so every task must end with the forced connection error.
    """

    async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]:
        """Delayed resolve() task."""
        # Yield to the loop a few times so other connect() calls can pile up
        # while this resolution is still pending.
        for _ in range(3):
            await asyncio.sleep(0)
        return [
            {
                "hostname": "localhost",
                "host": "127.0.0.1",
                "port": 80,
                "family": socket.AF_INET,
                "proto": 0,
                "flags": socket.AI_NUMERICHOST,
            },
        ]

    conn = aiohttp.TCPConnector(force_close=True)
    req = make_client_request(
        "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock()
    )
    with (
        mock.patch.object(conn._resolver, "resolve", delay_resolve),
        mock.patch(
            "aiohttp.connector.aiohappyeyeballs.start_connection",
            side_effect=OSError(1, "Forced connection to fail"),
        ),
    ):
        task1 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))

        # Let it create the internal task
        await asyncio.sleep(0)
        # Let that task start running
        await asyncio.sleep(0)

        # Ensure the task is running
        assert len(conn._resolve_host_tasks) == 1

        # These start while the first resolution is still in flight.
        task2 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))
        task3 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))

        with pytest.raises(
            aiohttp.ClientConnectorError, match="Forced connection to fail"
        ):
            await task1

        # Verify the task is finished
        assert len(conn._resolve_host_tasks) == 0

        with pytest.raises(
            aiohttp.ClientConnectorError, match="Forced connection to fail"
        ):
            await task2
        with pytest.raises(
            aiohttp.ClientConnectorError, match="Forced connection to fail"
        ):
            await task3
async def test_multiple_dns_resolution_requests_first_cancelled(
    loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker
) -> None:
    """Verify that first DNS resolution cancellation does not make other resolutions fail.

    The first ``connect()`` task is cancelled while resolution is pending. The
    two later tasks must still get past DNS resolution and fail only with the
    forced connection error, not with a cancellation.
    """

    async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]:
        """Delayed resolve() task."""
        # Yield to the loop a few times so the resolution stays pending while
        # the other tasks are created and the first one is cancelled.
        for _ in range(3):
            await asyncio.sleep(0)
        return [
            {
                "hostname": "localhost",
                "host": "127.0.0.1",
                "port": 80,
                "family": socket.AF_INET,
                "proto": 0,
                "flags": socket.AI_NUMERICHOST,
            },
        ]

    conn = aiohttp.TCPConnector(force_close=True)
    req = make_client_request(
        "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock()
    )
    with (
        mock.patch.object(conn._resolver, "resolve", delay_resolve),
        mock.patch(
            "aiohttp.connector.aiohappyeyeballs.start_connection",
            side_effect=OSError(1, "Forced connection to fail"),
        ),
    ):
        task1 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))

        # Let it create the internal task
        await asyncio.sleep(0)
        # Let that task start running
        await asyncio.sleep(0)

        # Ensure the task is running
        assert len(conn._resolve_host_tasks) == 1

        task2 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))
        task3 = asyncio.create_task(conn.connect(req, [], ClientTimeout()))

        task1.cancel()
        with pytest.raises(asyncio.CancelledError):
            await task1

        # The second and third tasks should still make the connection
        # even if the first one is cancelled
        with pytest.raises(
            aiohttp.ClientConnectorError, match="Forced connection to fail"
        ):
            await task2
        with pytest.raises(
            aiohttp.ClientConnectorError, match="Forced connection to fail"
        ):
            await task3

        # Verify the task is finished
        assert len(conn._resolve_host_tasks) == 0
async def test_close_abort_closed_transports(loop: asyncio.AbstractEventLoop) -> None:
    """Closing a connector aborts transports queued for closed-transport cleanup."""
    connector = aiohttp.BaseConnector()
    pending_transport = mock.Mock()
    connector._cleanup_closed_transports.append(pending_transport)

    await connector.close()

    # The cleanup list is drained, the transport was aborted, and the
    # connector reports itself closed.
    assert not connector._cleanup_closed_transports
    assert pending_transport.abort.called
    assert connector.closed
async def test_connect_with_limit(
    loop: asyncio.AbstractEventLoop,
    key: ConnectionKey,
    make_client_request: _RequestMaker,
) -> None:
    """With ``limit=1`` a second ``connect()`` waits until the first is released.

    NOTE(review): presumably the ``key`` fixture equals the connection key of
    the request built below, so the pre-seeded pooled connection is reused --
    confirm against the fixture.
    """
    proto = create_mocked_conn(loop)
    proto.is_connected.return_value = True

    req = make_client_request(
        "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock()
    )

    conn = aiohttp.BaseConnector(limit=1, limit_per_host=10)
    # Seed the pool with one idle connection for ``key``.
    conn._conns[key] = deque([(proto, loop.time())])
    with mock.patch.object(
        conn, "_create_connection", autospec=True, spec_set=True, return_value=proto
    ):
        connection1 = await conn.connect(req, [], ClientTimeout())
        assert connection1._protocol == proto

        # The acquired connection is tracked both globally and per host.
        assert 1 == len(conn._acquired)
        assert proto in conn._acquired
        assert key in conn._acquired_per_host
        assert proto in conn._acquired_per_host[key]

        acquired = False

        async def f() -> None:
            # Second acquirer: records when connect() returned, then checks
            # that only one connection is held at that point.
            nonlocal acquired
            connection2 = await conn.connect(req, [], ClientTimeout())
            acquired = True
            assert 1 == len(conn._acquired)
            assert 1 == len(conn._acquired_per_host[key])
            connection2.release()

        task = loop.create_task(f())

        # Give the second acquirer time to run; it must still be blocked.
        await asyncio.sleep(0.01)
        assert not acquired
        connection1.release()
        # One loop iteration is expected to be enough for the waiter to wake.
        await asyncio.sleep(0)
        assert acquired
        await task  # type: ignore[unreachable]
        await conn.close()
async def test_connect_reuseconn_tracing(
    loop: asyncio.AbstractEventLoop,
    key: ConnectionKey,
    make_client_request: _RequestMaker,
) -> None:
    """Reusing a pooled connection fires the ``on_connection_reuseconn`` signal."""
    client_session = mock.Mock()
    ctx = mock.Mock()
    reuse_callback = mock.AsyncMock()

    # Trace config whose context factory always hands back ``ctx``.
    ctx_factory = mock.Mock(return_value=ctx)
    config = aiohttp.TraceConfig(trace_config_ctx_factory=ctx_factory)
    config.on_connection_reuseconn.append(reuse_callback)
    config.freeze()
    traces = [Trace(client_session, config, config.trace_config_ctx())]

    pooled_proto = create_mocked_conn(loop)
    pooled_proto.is_connected.return_value = True

    request = make_client_request(
        "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock()
    )

    connector = aiohttp.BaseConnector(limit=1)
    # Pre-populate the pool so connect() can pick up an existing connection.
    connector._conns[key] = deque([(pooled_proto, loop.time())])
    connection = await connector.connect(request, traces, ClientTimeout())
    connection.release()

    expected_params = aiohttp.TraceConnectionReuseconnParams()
    reuse_callback.assert_called_with(client_session, ctx, expected_params)
    await connector.close()
async def test_connect_with_limit_and_limit_per_host(
    loop: asyncio.AbstractEventLoop,
    key: ConnectionKey,
    make_client_request: _RequestMaker,
) -> None:
    """With a large global limit, ``limit_per_host=1`` still blocks a second connect.

    NOTE(review): presumably the ``key`` fixture matches the connection key of
    the request built below -- confirm against the fixture.
    """
    proto = create_mocked_conn(loop)
    proto.is_connected.return_value = True

    req = make_client_request("GET", URL("http://localhost:80"), loop=loop)

    conn = aiohttp.BaseConnector(limit=1000, limit_per_host=1)
    # Seed the pool with one idle connection for ``key``.
    conn._conns[key] = deque([(proto, loop.time())])
    with mock.patch.object(
        conn, "_create_connection", autospec=True, spec_set=True, return_value=proto
    ):
        acquired = False
        connection1 = await conn.connect(req, [], ClientTimeout())

        async def f() -> None:
            # Second acquirer: flips ``acquired`` once connect() returns and
            # checks that exactly one connection is held at that point.
            nonlocal acquired
            connection2 = await conn.connect(req, [], ClientTimeout())
            acquired = True
            assert 1 == len(conn._acquired)
            assert 1 == len(conn._acquired_per_host[key])
            connection2.release()

        task = loop.create_task(f())

        # The second acquirer must still be waiting while connection1 is held.
        await asyncio.sleep(0.01)
        assert not acquired
        connection1.release()
        # One loop iteration is expected to be enough for the waiter to wake.
        await asyncio.sleep(0)
        assert acquired
        await task  # type: ignore[unreachable]
        await conn.close()
async def test_connect_with_limit_cancelled(
    loop: asyncio.AbstractEventLoop,
    key: ConnectionKey,
    make_client_request: _RequestMaker,
) -> None:
    """A ``connect()`` that waits on an exhausted limit can be timed out."""
    pooled_proto = create_mocked_conn(loop)
    pooled_proto.is_connected.return_value = True

    request = make_client_request("GET", URL("http://host:80"), loop=loop)

    connector = aiohttp.BaseConnector(limit=1)
    connector._conns[key] = deque([(pooled_proto, loop.time())])
    create_patch = mock.patch.object(
        connector,
        "_create_connection",
        autospec=True,
        spec_set=True,
        return_value=pooled_proto,
    )
    with create_patch:
        held = await connector.connect(request, [], ClientTimeout())
        assert held._protocol == pooled_proto
        assert held.transport == pooled_proto.transport
        assert 1 == len(connector._acquired)

        # limit exhausted
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(
                connector.connect(request, [], ClientTimeout()), 0.01
            )
        held.close()
        await connector.close()
mock.patch.object( conn, "_create_connection", autospec=True, spec_set=True, side_effect=err ): with pytest.raises(Exception): req = mock.Mock() await conn.connect(req, [], ClientTimeout()) assert not conn._waiters await conn.close() await check_with_exc(OSError(1, "permission error")) await check_with_exc(RuntimeError()) await check_with_exc(asyncio.TimeoutError()) async def test_connect_with_limit_concurrent( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: proto = create_mocked_conn(loop) proto.should_close = False proto.is_connected.return_value = True req = make_client_request("GET", URL("http://host:80"), loop=loop) max_connections = 2 num_connections = 0 conn = aiohttp.BaseConnector(limit=max_connections) # Use a real coroutine for _create_connection; a mock would mask # problems that only happen when the method yields. async def create_connection( req: object, traces: object, timeout: object ) -> ResponseHandler: nonlocal num_connections num_connections += 1 await asyncio.sleep(0) # Make a new transport mock each time because acquired # transports are stored in a set. Reusing the same object # messes with the count. proto = create_mocked_conn(loop, should_close=False) proto.is_connected.return_value = True return proto # Simulate something like a crawler. It opens a connection, does # something with it, closes it, then creates tasks that make more # connections and waits for them to finish. The crawler is started # with multiple concurrent requests and stops when it hits a # predefined maximum number of requests. 
async def test_connect_waiters_cleanup_key_error(
    loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker
) -> None:
    """Cancelling a queued ``connect()`` whose waiter entry is already gone is safe.

    The waiter bookkeeping is wiped by hand before the cancelled task gets to
    run its own cleanup; that cleanup must neither fail on the missing key nor
    leave an entry behind.
    """
    proto = create_mocked_conn(loop)
    proto.is_connected.return_value = True

    req = make_client_request("GET", URL("http://host:80"), loop=loop)

    conn = aiohttp.BaseConnector(limit=1, limit_per_host=10)
    # Report zero available connections so connect() has to queue a waiter.
    with mock.patch.object(
        conn, "_available_connections", autospec=True, spec_set=True, return_value=0
    ):
        t = loop.create_task(conn.connect(req, [], ClientTimeout()))

        await asyncio.sleep(0)
        assert conn._waiters.keys()

        # we delete the entry explicitly before the
        # canceled connection grabs the loop again, we
        # must expect a none failure termination
        conn._waiters.clear()
        t.cancel()
        await asyncio.sleep(0)
        # The previous ``not keys() == []`` compared a keys view with a list,
        # which is never equal, so it could not fail. Check emptiness directly,
        # as test_connect_waiters_cleanup does.
        assert not conn._waiters.keys()
        await conn.close()
async def test_force_close_and_explicit_keep_alive(
    loop: asyncio.AbstractEventLoop,
) -> None:
    """``force_close=True`` is rejected only with a non-None ``keepalive_timeout``."""
    accepted_kwargs: tuple[dict[str, object], ...] = (
        {"force_close": True},
        {"force_close": True, "keepalive_timeout": None},
    )
    # These combinations must construct without raising.
    for kwargs in accepted_kwargs:
        aiohttp.BaseConnector(**kwargs)  # type: ignore[arg-type]

    # An explicit keep-alive timeout contradicts force_close.
    with pytest.raises(ValueError):
        aiohttp.BaseConnector(keepalive_timeout=30, force_close=True)
async def test_cancelled_waiter(loop: asyncio.AbstractEventLoop) -> None:
    """A ``connect()`` queued behind an exhausted limit can be cancelled cleanly."""
    conn = aiohttp.BaseConnector(limit=1)
    req = mock.Mock()
    req.connection_key = "key"
    proto = create_mocked_conn(loop)

    # Same (req, traces, timeout) signature as the other _create_connection
    # stubs in this module, so the stub stays callable if it is ever reached.
    async def create_connection(
        req: object, traces: object = None, timeout: object = None
    ) -> ResponseHandler:
        await asyncio.sleep(1)
        return proto

    with mock.patch.object(conn, "_create_connection", create_connection):
        # Mark the only slot as taken so the next connect() has to wait.
        conn._acquired.add(proto)

        conn2 = loop.create_task(conn.connect(req, [], ClientTimeout()))
        await asyncio.sleep(0)
        conn2.cancel()

        with pytest.raises(asyncio.CancelledError):
            await conn2
        await conn.close()
@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires UNIX sockets")
async def test_unix_connector_permission(  # type: ignore[misc]
    loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker
) -> None:
    """A ``PermissionError`` from the loop surfaces as ``ClientConnectorError``."""
    denied = mock.AsyncMock(side_effect=PermissionError())
    socket_path = f"/{uuid.uuid4().hex}"
    with mock.patch.object(loop, "create_unix_connection", denied):
        unix_connector = aiohttp.UnixConnector(socket_path)

        request = make_client_request(
            "GET", URL("http://www.python.org"), loop=loop
        )
        with pytest.raises(aiohttp.ClientConnectorError):
            await unix_connector.connect(request, [], ClientTimeout())
@pytest.mark.skipif( platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) async def test_named_pipe_connector_wrong_loop( selector_loop: asyncio.AbstractEventLoop, pipe_name: str ) -> None: with pytest.raises(RuntimeError): aiohttp.NamedPipeConnector(pipe_name) @pytest.mark.skipif( platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) async def test_named_pipe_connector_not_found( # type: ignore[misc] proactor_loop: asyncio.AbstractEventLoop, pipe_name: str, make_client_request: _RequestMaker, ) -> None: asyncio.set_event_loop(proactor_loop) connector = aiohttp.NamedPipeConnector(pipe_name) req = make_client_request("GET", URL("http://www.python.org"), loop=proactor_loop) with pytest.raises(aiohttp.ClientConnectorError): await connector.connect(req, [], ClientTimeout()) @pytest.mark.skipif( platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) async def test_named_pipe_connector_permission( # type: ignore[misc] proactor_loop: asyncio.AbstractEventLoop, pipe_name: str, make_client_request: _RequestMaker, ) -> None: m = mock.AsyncMock(side_effect=PermissionError()) with mock.patch.object(proactor_loop, "create_pipe_connection", m): asyncio.set_event_loop(proactor_loop) connector = aiohttp.NamedPipeConnector(pipe_name) req = make_client_request( "GET", URL("http://www.python.org"), loop=proactor_loop ) with pytest.raises(aiohttp.ClientConnectorError): await connector.connect(req, [], ClientTimeout()) async def test_default_use_dns_cache() -> None: conn = aiohttp.TCPConnector() assert conn.use_dns_cache await conn.close() async def test_resolver_not_called_with_address_is_ip( loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: resolver = mock.MagicMock() connector = aiohttp.TCPConnector(resolver=resolver) req = make_client_request( "GET", URL(f"http://127.0.0.1:{unused_port()}"), loop=loop, response_class=mock.Mock(), ) with 
pytest.raises(OSError): await connector.connect(req, [], ClientTimeout()) resolver.resolve.assert_not_called() await connector.close() async def test_tcp_connector_raise_connector_ssl_error( aiohttp_server: AiohttpServer, ssl_ctx: ssl.SSLContext ) -> None: async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() app.router.add_get("/", handler) srv = await aiohttp_server(app, ssl=ssl_ctx) port = unused_port() conn = aiohttp.TCPConnector(local_addr=("127.0.0.1", port)) session = aiohttp.ClientSession(connector=conn) url = srv.make_url("/") err = aiohttp.ClientConnectorCertificateError with pytest.raises(err) as ctx: await session.get(url) assert isinstance(ctx.value, aiohttp.ClientConnectorCertificateError) assert isinstance(ctx.value.certificate_error, ssl.SSLError) await session.close() await conn.close() @pytest.mark.parametrize( "host", ( pytest.param("127.0.0.1", id="ip address"), pytest.param("localhost", id="domain name"), pytest.param("localhost.", id="fully-qualified domain name"), pytest.param( "localhost...", id="fully-qualified domain name with multiple trailing dots" ), pytest.param("príklad.localhost.", id="idna fully-qualified domain name"), ), ) async def test_tcp_connector_do_not_raise_connector_ssl_error( aiohttp_server: AiohttpServer, ssl_ctx: ssl.SSLContext, client_ssl_ctx: ssl.SSLContext, host: str, ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) srv = await aiohttp_server(app, ssl=ssl_ctx) port = unused_port() conn = aiohttp.TCPConnector(local_addr=("127.0.0.1", port)) # resolving something.localhost with the real DNS resolver does not work on macOS, so we have a stub. 
async def _resolve_host( host: str, port: int, traces: object = None ) -> list[ResolveResult]: return [ { "hostname": host, "host": "127.0.0.1", "port": port, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, }, { "hostname": host, "host": "::1", "port": port, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, }, ] with mock.patch.object( conn, "_resolve_host", autospec=True, spec_set=True, side_effect=_resolve_host ): session = aiohttp.ClientSession(connector=conn) url = srv.make_url("/") r = await session.get(url.with_host(host), ssl=client_ssl_ctx) r.release() first_conn = next(iter(conn._conns.values()))[0][0] assert first_conn.transport is not None _sslcontext = first_conn.transport._ssl_protocol._sslcontext # type: ignore[attr-defined] assert _sslcontext is client_ssl_ctx r.close() await session.close() await conn.close() async def test_tcp_connector_uses_provided_local_addr( aiohttp_server: AiohttpServer, ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) srv = await aiohttp_server(app) port = unused_port() conn = aiohttp.TCPConnector(local_addr=("127.0.0.1", port)) session = aiohttp.ClientSession(connector=conn) url = srv.make_url("/") r = await session.get(url) r.release() first_conn = next(iter(conn._conns.values()))[0][0] assert first_conn.transport is not None assert first_conn.transport.get_extra_info("sockname") == ("127.0.0.1", port) r.close() await session.close() await conn.close() async def test_unix_connector( unix_server: Callable[[web.Application], Awaitable[None]], unix_sockname: str ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) await unix_server(app) url = "http://127.0.0.1/" connector = aiohttp.UnixConnector(unix_sockname) assert unix_sockname == connector.path assert connector.allowed_protocol_schema_set == { 
"", "http", "https", "ws", "wss", "unix", } session = ClientSession(connector=connector) r = await session.get(url) assert r.status == 200 r.close() await session.close() @pytest.mark.skipif( platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) async def test_named_pipe_connector( proactor_loop: asyncio.AbstractEventLoop, named_pipe_server: Callable[[web.Application], Awaitable[None]], pipe_name: str, ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) await named_pipe_server(app) url = "http://this-does-not-matter.com" connector = aiohttp.NamedPipeConnector(pipe_name) assert pipe_name == connector.path assert connector.allowed_protocol_schema_set == { "", "http", "https", "ws", "wss", "npipe", } session = ClientSession(connector=connector) r = await session.get(url) assert r.status == 200 r.close() await session.close() class TestDNSCacheTable: host1 = ("localhost", 80) host2 = ("foo", 80) result1: ResolveResult = { "hostname": "localhost", "host": "127.0.0.1", "port": 80, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } result2: ResolveResult = { "hostname": "foo", "host": "127.0.0.2", "port": 80, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } @pytest.fixture def dns_cache_table(self) -> _DNSCacheTable: return _DNSCacheTable() def test_next_addrs_basic(self, dns_cache_table: _DNSCacheTable) -> None: dns_cache_table.add(self.host1, [self.result1]) dns_cache_table.add(self.host2, [self.result2]) addrs = dns_cache_table.next_addrs(self.host1) assert addrs == [self.result1] addrs = dns_cache_table.next_addrs(self.host2) assert addrs == [self.result2] with pytest.raises(KeyError): dns_cache_table.next_addrs(("no-such-host", 80)) def test_remove(self, dns_cache_table: _DNSCacheTable) -> None: dns_cache_table.add(self.host1, [self.result1]) dns_cache_table.remove(self.host1) with 
pytest.raises(KeyError): dns_cache_table.next_addrs(self.host1) def test_clear(self, dns_cache_table: _DNSCacheTable) -> None: dns_cache_table.add(self.host1, [self.result1]) dns_cache_table.clear() with pytest.raises(KeyError): dns_cache_table.next_addrs(self.host1) def test_not_expired_ttl_None(self, dns_cache_table: _DNSCacheTable) -> None: dns_cache_table.add(self.host1, [self.result1]) assert not dns_cache_table.expired(self.host1) def test_not_expired_ttl(self) -> None: dns_cache_table = _DNSCacheTable(ttl=0.1) dns_cache_table.add(self.host1, [self.result1]) assert not dns_cache_table.expired(self.host1) def test_expired_ttl(self, monkeypatch: pytest.MonkeyPatch) -> None: dns_cache_table = _DNSCacheTable(ttl=1) monkeypatch.setattr("aiohttp.connector.monotonic", lambda: 1) dns_cache_table.add(self.host1, [self.result1]) monkeypatch.setattr("aiohttp.connector.monotonic", lambda: 2) assert not dns_cache_table.expired(self.host1) monkeypatch.setattr("aiohttp.connector.monotonic", lambda: 3) assert dns_cache_table.expired(self.host1) def test_never_expire(self, monkeypatch: pytest.MonkeyPatch) -> None: dns_cache_table = _DNSCacheTable(ttl=None) monkeypatch.setattr("aiohttp.connector.monotonic", lambda: 1) dns_cache_table.add(self.host1, [self.result1]) monkeypatch.setattr("aiohttp.connector.monotonic", lambda: 10000000) assert not dns_cache_table.expired(self.host1) def test_always_expire(self, monkeypatch: pytest.MonkeyPatch) -> None: dns_cache_table = _DNSCacheTable(ttl=0) monkeypatch.setattr("aiohttp.connector.monotonic", lambda: 1) dns_cache_table.add(self.host1, [self.result1]) monkeypatch.setattr("aiohttp.connector.monotonic", lambda: 1.00001) assert dns_cache_table.expired(self.host1) def test_next_addrs(self, dns_cache_table: _DNSCacheTable) -> None: result3: ResolveResult = { "hostname": "foo", "host": "127.0.0.3", "port": 80, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } dns_cache_table.add(self.host2, [self.result1, 
self.result2, result3]) # Each calls to next_addrs return the hosts using # a round robin strategy. addrs = dns_cache_table.next_addrs(self.host2) assert addrs == [self.result1, self.result2, result3] addrs = dns_cache_table.next_addrs(self.host2) assert addrs == [self.result2, result3, self.result1] addrs = dns_cache_table.next_addrs(self.host2) assert addrs == [result3, self.result1, self.result2] addrs = dns_cache_table.next_addrs(self.host2) assert addrs == [self.result1, self.result2, result3] def test_next_addrs_single(self, dns_cache_table: _DNSCacheTable) -> None: dns_cache_table.add(self.host2, [self.result1]) addrs = dns_cache_table.next_addrs(self.host2) assert addrs == [self.result1] addrs = dns_cache_table.next_addrs(self.host2) assert addrs == [self.result1] def test_max_size_eviction(self) -> None: table = _DNSCacheTable(max_size=2) table.add(self.host1, [self.result1]) table.add(self.host2, [self.result2]) host3 = ("example.com", 80) result3: ResolveResult = { **self.result1, "hostname": "example.com", "host": "1.2.3.4", } table.add(host3, [result3]) assert len(table._addrs_rr) == 2 assert self.host1 not in table._addrs_rr assert host3 in table._addrs_rr def test_lru_eviction(self) -> None: table = _DNSCacheTable(max_size=2) table.add(self.host1, [self.result1]) table.add(self.host2, [self.result2]) table.next_addrs(self.host1) host3 = ("example.com", 80) result3: ResolveResult = { **self.result1, "hostname": "example.com", "host": "1.2.3.4", } table.add(host3, [result3]) assert self.host1 in table._addrs_rr assert self.host2 not in table._addrs_rr def test_lru_eviction_add(self) -> None: table = _DNSCacheTable(max_size=2) table.add(self.host1, [self.result1]) table.add(self.host2, [self.result2]) # Re-add, thus making host1 the most recently used. 
table.add(self.host1, [self.result1]) host3 = ("example.com", 80) result3: ResolveResult = { **self.result1, "hostname": "example.com", "host": "1.2.3.4", } table.add(host3, [result3]) assert self.host1 in table._addrs_rr assert self.host2 not in table._addrs_rr async def test_connector_cache_trace_race() -> None: class DummyTracer(Trace): def __init__(self) -> None: """Dummy""" async def send_dns_cache_hit(self, *args: object, **kwargs: object) -> None: connector._cached_hosts.remove(("", 0)) token: ResolveResult = { "hostname": "localhost", "host": "127.0.0.1", "port": 80, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } connector = TCPConnector() connector._cached_hosts.add(("", 0), [token]) traces = [DummyTracer()] assert await connector._resolve_host("", 0, traces) == [token] await connector.close() async def test_connector_throttle_trace_race(loop: asyncio.AbstractEventLoop) -> None: key = ("", 0) token: ResolveResult = { "hostname": "localhost", "host": "127.0.0.1", "port": 80, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } class DummyTracer(Trace): def __init__(self) -> None: """Dummy""" async def send_dns_cache_hit(self, *args: object, **kwargs: object) -> None: futures = connector._throttle_dns_futures.pop(key) for fut in futures: fut.set_result(None) connector._cached_hosts.add(key, [token]) connector = TCPConnector() connector._throttle_dns_futures[key] = set() traces = [DummyTracer()] assert await connector._resolve_host("", 0, traces) == [token] await connector.close() async def test_connector_resolve_in_case_of_trace_cache_miss_exception( loop: asyncio.AbstractEventLoop, ) -> None: token: ResolveResult = { "hostname": "localhost", "host": "127.0.0.1", "port": 80, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } request_count = 0 class DummyTracer(Trace): def __init__(self) -> None: """Dummy""" async def send_dns_cache_hit(self, *args: object, **kwargs: object) -> None: """Dummy 
send_dns_cache_hit""" async def send_dns_resolvehost_start( self, *args: object, **kwargs: object ) -> None: """Dummy send_dns_resolvehost_start""" async def send_dns_resolvehost_end( self, *args: object, **kwargs: object ) -> None: """Dummy send_dns_resolvehost_end""" async def send_dns_cache_miss(self, *args: object, **kwargs: object) -> None: nonlocal request_count request_count += 1 if request_count <= 1: raise Exception("first attempt") async def resolve_response( *_args: object, **_kwargs: object ) -> list[ResolveResult]: await asyncio.sleep(0) return [token] with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: mock_default_resolver = mock.create_autospec( AsyncResolver, instance=True, spec_set=True ) mock_default_resolver.resolve.side_effect = resolve_response m_resolver.return_value = mock_default_resolver connector = TCPConnector() traces = [DummyTracer()] with pytest.raises(Exception): await connector._resolve_host("", 0, traces) await connector._resolve_host("", 0, traces) == [token] await connector.close() async def test_connector_does_not_remove_needed_waiters( loop: asyncio.AbstractEventLoop, key: ConnectionKey, make_client_request: _RequestMaker, ) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True req = make_client_request("GET", URL("https://localhost:80"), loop=loop) connection_key = req.connection_key async def await_connection_and_check_waiters() -> None: connection = await connector.connect(req, [], ClientTimeout()) try: assert connection_key in connector._waiters assert dummy_waiter in connector._waiters[connection_key] finally: connection.close() async def allow_connection_and_add_dummy_waiter() -> None: list(connector._waiters[connection_key])[0].set_result(None) del connector._waiters[connection_key] connector._waiters[connection_key][dummy_waiter] = None connector = aiohttp.BaseConnector() with mock.patch.object( connector, "_available_connections", autospec=True, spec_set=True, side_effect=[0, 
1, 1, 1], ): connector._conns[key] = deque([(proto, loop.time())]) with mock.patch.object( connector, "_create_connection", autospec=True, spec_set=True, return_value=proto, ): dummy_waiter = loop.create_future() await asyncio.gather( await_connection_and_check_waiters(), allow_connection_and_add_dummy_waiter(), ) await connector.close() def test_connector_multiple_event_loop(make_client_request: _RequestMaker) -> None: """Test the connector with multiple event loops.""" async def async_connect() -> Literal[True]: conn = aiohttp.TCPConnector() loop = asyncio.get_running_loop() req = make_client_request("GET", URL("https://127.0.0.1"), loop=loop) with suppress(aiohttp.ClientConnectorError): with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True, side_effect=ssl.CertificateError, ): await conn.connect(req, [], ClientTimeout()) return True def test_connect() -> Literal[True]: loop = asyncio.new_event_loop() try: return loop.run_until_complete(async_connect()) finally: loop.close() with futures.ThreadPoolExecutor() as executor: res_list = [executor.submit(test_connect) for _ in range(2)] raw_response_list = [res.result() for res in futures.as_completed(res_list)] assert raw_response_list == [True, True] async def test_tcp_connector_socket_factory( loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock, make_client_request: _RequestMaker, ) -> None: """Check that socket factory is called""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: start_connection.return_value = s local_addr = None socket_factory: Callable[[AddrInfoType], socket.socket] = lambda _: s happy_eyeballs_delay = 0.123 interleave = 3 conn = aiohttp.TCPConnector( interleave=interleave, local_addr=local_addr, happy_eyeballs_delay=happy_eyeballs_delay, socket_factory=socket_factory, ) with mock.patch.object( conn._loop, "create_connection", autospec=True, spec_set=True, return_value=(mock.Mock(), mock.Mock()), ): host = "127.0.0.1" port = 443 req = 
make_client_request("GET", URL(f"https://{host}:{port}"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): pass await conn.close() start_connection.assert_called_with( addr_infos=[ (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port)) ], local_addr_infos=local_addr, happy_eyeballs_delay=happy_eyeballs_delay, interleave=interleave, loop=loop, socket_factory=socket_factory, ) def test_default_ssl_context_creation_without_ssl() -> None: """Verify _make_ssl_context does not raise when ssl is not available.""" with mock.patch.object(connector_module, "ssl", None): assert connector_module._make_ssl_context(False) is None assert connector_module._make_ssl_context(True) is None def _acquired_connection( conn: aiohttp.BaseConnector, proto: ResponseHandler, key: ConnectionKey ) -> Connection: conn._acquired.add(proto) conn._acquired_per_host[key].add(proto) return Connection(conn, key, proto, conn._loop) async def test_available_connections_with_limit_per_host( key: ConnectionKey, other_host_key2: ConnectionKey ) -> None: """Verify expected values based on active connections with host limit.""" conn = aiohttp.BaseConnector(limit=3, limit_per_host=2) assert conn._available_connections(key) == 2 assert conn._available_connections(other_host_key2) == 2 proto1 = create_mocked_conn() connection1 = _acquired_connection(conn, proto1, key) assert conn._available_connections(key) == 1 assert conn._available_connections(other_host_key2) == 2 proto2 = create_mocked_conn() connection2 = _acquired_connection(conn, proto2, key) assert conn._available_connections(key) == 0 assert conn._available_connections(other_host_key2) == 1 connection1.close() assert conn._available_connections(key) == 1 assert conn._available_connections(other_host_key2) == 2 connection2.close() other_proto1 = create_mocked_conn() other_connection1 = _acquired_connection(conn, other_proto1, other_host_key2) assert conn._available_connections(key) == 2 assert 
conn._available_connections(other_host_key2) == 1 other_connection1.close() assert conn._available_connections(key) == 2 assert conn._available_connections(other_host_key2) == 2 @pytest.mark.parametrize("limit_per_host", [0, 10]) async def test_available_connections_without_limit_per_host( # type: ignore[misc] key: ConnectionKey, other_host_key2: ConnectionKey, limit_per_host: int ) -> None: """Verify expected values based on active connections with higher host limit.""" conn = aiohttp.BaseConnector(limit=3, limit_per_host=limit_per_host) assert conn._available_connections(key) == 3 assert conn._available_connections(other_host_key2) == 3 proto1 = create_mocked_conn() connection1 = _acquired_connection(conn, proto1, key) assert conn._available_connections(key) == 2 assert conn._available_connections(other_host_key2) == 2 proto2 = create_mocked_conn() connection2 = _acquired_connection(conn, proto2, key) assert conn._available_connections(key) == 1 assert conn._available_connections(other_host_key2) == 1 connection1.close() assert conn._available_connections(key) == 2 assert conn._available_connections(other_host_key2) == 2 connection2.close() other_proto1 = create_mocked_conn() other_connection1 = _acquired_connection(conn, other_proto1, other_host_key2) assert conn._available_connections(key) == 2 assert conn._available_connections(other_host_key2) == 2 other_connection1.close() assert conn._available_connections(key) == 3 assert conn._available_connections(other_host_key2) == 3 async def test_available_connections_no_limits( key: ConnectionKey, other_host_key2: ConnectionKey ) -> None: """Verify expected values based on active connections with no limits.""" # No limits is a special case where available connections should always be 1. 
conn = aiohttp.BaseConnector(limit=0, limit_per_host=0) assert conn._available_connections(key) == 1 assert conn._available_connections(other_host_key2) == 1 proto1 = create_mocked_conn() connection1 = _acquired_connection(conn, proto1, key) assert conn._available_connections(key) == 1 assert conn._available_connections(other_host_key2) == 1 connection1.close() assert conn._available_connections(key) == 1 assert conn._available_connections(other_host_key2) == 1 async def test_connect_tunnel_connection_release( loop: asyncio.AbstractEventLoop, ) -> None: """Test _ConnectTunnelConnection.release() does not pool the connection.""" connector = mock.create_autospec( aiohttp.BaseConnector, spec_set=True, instance=True ) key = mock.create_autospec(ConnectionKey, spec_set=True, instance=True) protocol = mock.create_autospec(ResponseHandler, spec_set=True, instance=True) # Create a connect tunnel connection conn = _ConnectTunnelConnection(connector, key, protocol, loop) # Verify protocol is set assert conn._protocol is protocol # Release should do nothing (not pool the connection) conn.release() # Protocol should still be there (not released to pool) assert conn._protocol is protocol # Connector._release should NOT have been called connector._release.assert_not_called() # Clean up to avoid resource warning conn.close() ================================================ FILE: tests/test_cookie_helpers.py ================================================ """Tests for internal cookie helper functions.""" import logging import sys import time from http.cookies import ( CookieError, Morsel, SimpleCookie, _unquote as simplecookie_unquote, ) import pytest from aiohttp import _cookie_helpers as helpers from aiohttp._cookie_helpers import ( _unquote, parse_cookie_header, parse_set_cookie_headers, preserve_morsel_with_coded_value, ) def test_known_attrs_is_superset_of_morsel_reserved() -> None: """Test that _COOKIE_KNOWN_ATTRS contains all Morsel._reserved attributes.""" # Get 
Morsel._reserved attributes (lowercase) morsel_reserved = {attr.lower() for attr in Morsel._reserved} # type: ignore[attr-defined] # _COOKIE_KNOWN_ATTRS should be a superset of morsel_reserved assert ( helpers._COOKIE_KNOWN_ATTRS >= morsel_reserved ), f"_COOKIE_KNOWN_ATTRS is missing: {morsel_reserved - helpers._COOKIE_KNOWN_ATTRS}" def test_bool_attrs_is_superset_of_morsel_flags() -> None: """Test that _COOKIE_BOOL_ATTRS contains all Morsel._flags attributes.""" # Get Morsel._flags attributes (lowercase) morsel_flags = {attr.lower() for attr in Morsel._flags} # type: ignore[attr-defined] # _COOKIE_BOOL_ATTRS should be a superset of morsel_flags assert ( helpers._COOKIE_BOOL_ATTRS >= morsel_flags ), f"_COOKIE_BOOL_ATTRS is missing: {morsel_flags - helpers._COOKIE_BOOL_ATTRS}" def test_preserve_morsel_with_coded_value() -> None: """Test preserve_morsel_with_coded_value preserves coded_value exactly.""" # Create a cookie with a coded_value different from value cookie: Morsel[str] = Morsel() cookie.set("test_cookie", "decoded value", "encoded%20value") # Preserve the coded_value result = preserve_morsel_with_coded_value(cookie) # Check that all values are preserved assert result.key == "test_cookie" assert result.value == "decoded value" assert result.coded_value == "encoded%20value" # Should be a different Morsel instance assert result is not cookie def test_preserve_morsel_with_coded_value_no_coded_value() -> None: """Test preserve_morsel_with_coded_value when coded_value is same as value.""" cookie: Morsel[str] = Morsel() cookie.set("test_cookie", "simple_value", "simple_value") result = preserve_morsel_with_coded_value(cookie) assert result.key == "test_cookie" assert result.value == "simple_value" assert result.coded_value == "simple_value" def test_parse_set_cookie_headers_simple() -> None: """Test parse_set_cookie_headers with simple cookies.""" headers = ["name=value", "session=abc123"] result = parse_set_cookie_headers(headers) assert len(result) == 2 assert 
result[0][0] == "name" assert result[0][1].key == "name" assert result[0][1].value == "value" assert result[1][0] == "session" assert result[1][1].key == "session" assert result[1][1].value == "abc123" def test_parse_set_cookie_headers_with_attributes() -> None: """Test parse_set_cookie_headers with cookie attributes.""" headers = [ "sessionid=value123; Path=/; HttpOnly; Secure", "user=john; Domain=.example.com; Max-Age=3600", ] result = parse_set_cookie_headers(headers) assert len(result) == 2 # First cookie name1, morsel1 = result[0] assert name1 == "sessionid" assert morsel1.value == "value123" assert morsel1["path"] == "/" assert morsel1["httponly"] is True assert morsel1["secure"] is True # Second cookie name2, morsel2 = result[1] assert name2 == "user" assert morsel2.value == "john" assert morsel2["domain"] == ".example.com" assert morsel2["max-age"] == "3600" def test_parse_set_cookie_headers_special_chars_in_names() -> None: """Test parse_set_cookie_headers accepts special characters in names (#2683).""" # These should be accepted with relaxed validation headers = [ "ISAWPLB{A7F52349-3531-4DA9-8776-F74BC6F4F1BB}=value1", "cookie[index]=value2", "cookie(param)=value3", "cookie:name=value4", "cookie@domain=value5", ] result = parse_set_cookie_headers(headers) assert len(result) == 5 expected_names = [ "ISAWPLB{A7F52349-3531-4DA9-8776-F74BC6F4F1BB}", "cookie[index]", "cookie(param)", "cookie:name", "cookie@domain", ] for i, (name, morsel) in enumerate(result): assert name == expected_names[i] assert morsel.key == expected_names[i] assert morsel.value == f"value{i+1}" def test_parse_set_cookie_headers_invalid_names() -> None: """Test parse_set_cookie_headers rejects truly invalid cookie names.""" # These should be rejected even with relaxed validation headers = [ "invalid\tcookie=value", # Tab character "invalid\ncookie=value", # Newline "invalid\rcookie=value", # Carriage return "\x00badname=value", # Null character "name with spaces=value", # Spaces in name ] 
result = parse_set_cookie_headers(headers) # All should be skipped assert len(result) == 0 def test_parse_set_cookie_headers_empty_and_invalid() -> None: """Test parse_set_cookie_headers handles empty and invalid formats.""" headers = [ "", # Empty header " ", # Whitespace only "=value", # No name "name=", # Empty value (should be accepted) "justname", # No value (should be skipped) "path=/", # Reserved attribute as name (should be skipped) "Domain=.com", # Reserved attribute as name (should be skipped) ] result = parse_set_cookie_headers(headers) # Only "name=" should be accepted assert len(result) == 1 assert result[0][0] == "name" assert result[0][1].value == "" def test_parse_set_cookie_headers_quoted_values() -> None: """Test parse_set_cookie_headers handles quoted values correctly.""" headers = [ 'name="quoted value"', 'session="with;semicolon"', 'data="with\\"escaped\\""', ] result = parse_set_cookie_headers(headers) assert len(result) == 3 assert result[0][1].value == "quoted value" assert result[1][1].value == "with;semicolon" assert result[2][1].value == 'with"escaped"' @pytest.mark.parametrize( "header", [ 'session="abc;xyz"; token=123', 'data="value;with;multiple;semicolons"; next=cookie', 'complex="a=b;c=d"; simple=value', ], ) def test_parse_set_cookie_headers_semicolon_in_quoted_values(header: str) -> None: """ Test that semicolons inside properly quoted values are handled correctly. Cookie values can contain semicolons when properly quoted. This test ensures that our parser handles these cases correctly, matching SimpleCookie behavior. 
""" # Test with SimpleCookie sc = SimpleCookie() sc.load(header) # Test with our parser result = parse_set_cookie_headers([header]) # Should parse the same number of cookies assert len(result) == len(sc) # Verify each cookie matches SimpleCookie for (name, morsel), (sc_name, sc_morsel) in zip(result, sc.items()): assert name == sc_name assert morsel.value == sc_morsel.value def test_parse_set_cookie_headers_multiple_cookies_same_header() -> None: """Test parse_set_cookie_headers with multiple cookies in one header.""" # Note: SimpleCookie includes the comma as part of the first cookie's value headers = ["cookie1=value1, cookie2=value2"] result = parse_set_cookie_headers(headers) # Should parse as two separate cookies assert len(result) == 2 assert result[0][0] == "cookie1" assert result[0][1].value == "value1," # Comma is included in the value assert result[1][0] == "cookie2" assert result[1][1].value == "value2" @pytest.mark.parametrize( "header", [ # Standard cookies "session=abc123", "user=john; Path=/", "token=xyz; Secure; HttpOnly", # Empty values "empty=", # Quoted values 'quoted="value with spaces"', # Multiple attributes "complex=value; Domain=.example.com; Path=/app; Max-Age=3600", ], ) def test_parse_set_cookie_headers_compatibility_with_simple_cookie(header: str) -> None: """Test parse_set_cookie_headers is bug-for-bug compatible with SimpleCookie.load.""" # Parse with SimpleCookie sc = SimpleCookie() sc.load(header) # Parse with our function result = parse_set_cookie_headers([header]) # Should have same number of cookies assert len(result) == len(sc) # Compare each cookie for name, morsel in result: assert name in sc sc_morsel = sc[name] # Compare values assert morsel.value == sc_morsel.value assert morsel.key == sc_morsel.key # Compare attributes (only those that SimpleCookie would set) for attr in ["path", "domain", "max-age"]: assert morsel.get(attr) == sc_morsel.get(attr) # Boolean attributes are handled differently # SimpleCookie sets them to empty 
string when not present, True when present for bool_attr in ["secure", "httponly"]: # Only check if SimpleCookie has the attribute set to True if sc_morsel.get(bool_attr) is True: assert morsel.get(bool_attr) is True def test_parse_set_cookie_headers_relaxed_validation_differences() -> None: """Test where parse_set_cookie_headers differs from SimpleCookie (relaxed validation).""" # Test cookies that SimpleCookie rejects with CookieError rejected_by_simplecookie = [ ("cookie{with}braces=value1", "cookie{with}braces", "value1"), ("cookie(with)parens=value3", "cookie(with)parens", "value3"), ("cookie@with@at=value5", "cookie@with@at", "value5"), ] for header, expected_name, expected_value in rejected_by_simplecookie: # SimpleCookie should reject these with CookieError sc = SimpleCookie() with pytest.raises(CookieError): sc.load(header) # Our parser should accept them result = parse_set_cookie_headers([header]) assert len(result) == 1 # We accept assert result[0][0] == expected_name assert result[0][1].value == expected_value # Test cookies that SimpleCookie accepts (but we handle more consistently) accepted_by_simplecookie = [ ("cookie[with]brackets=value2", "cookie[with]brackets", "value2"), ("cookie:with:colons=value4", "cookie:with:colons", "value4"), ] for header, expected_name, expected_value in accepted_by_simplecookie: # SimpleCookie accepts these sc = SimpleCookie() sc.load(header) # May or may not parse correctly in SimpleCookie # Our parser should accept them consistently result = parse_set_cookie_headers([header]) assert len(result) == 1 assert result[0][0] == expected_name assert result[0][1].value == expected_value def test_parse_set_cookie_headers_case_insensitive_attrs() -> None: """Test that known attributes are handled case-insensitively.""" headers = [ "cookie1=value1; PATH=/test; DOMAIN=example.com", "cookie2=value2; Secure; HTTPONLY; max-AGE=60", ] result = parse_set_cookie_headers(headers) assert len(result) == 2 # First cookie - attributes should 
be recognized despite case assert result[0][1]["path"] == "/test" assert result[0][1]["domain"] == "example.com" # Second cookie assert result[1][1]["secure"] is True assert result[1][1]["httponly"] is True assert result[1][1]["max-age"] == "60" def test_parse_set_cookie_headers_unknown_attrs_ignored() -> None: """Test that unknown attributes are treated as new cookies (same as SimpleCookie).""" headers = [ "cookie=value; Path=/; unknownattr=ignored; HttpOnly", ] result = parse_set_cookie_headers(headers) # SimpleCookie treats unknown attributes with values as new cookies assert len(result) == 2 # First cookie assert result[0][0] == "cookie" assert result[0][1]["path"] == "/" assert result[0][1]["httponly"] == "" # Not set on first cookie # Second cookie (the unknown attribute) assert result[1][0] == "unknownattr" assert result[1][1].value == "ignored" assert result[1][1]["httponly"] is True # HttpOnly applies to this cookie def test_parse_set_cookie_headers_complex_real_world() -> None: """Test parse_set_cookie_headers with complex real-world examples.""" headers = [ # AWS ELB cookie "AWSELB=ABCDEF1234567890ABCDEF1234567890ABCDEF1234567890; Path=/", # Google Analytics "_ga=GA1.2.1234567890.1234567890; Domain=.example.com; Path=/; Expires=Thu, 31-Dec-2025 23:59:59 GMT", # Session with all attributes "session_id=s%3AabcXYZ123.signature123; Path=/; Secure; HttpOnly; SameSite=Strict", ] result = parse_set_cookie_headers(headers) assert len(result) == 3 # Check each cookie parsed correctly assert result[0][0] == "AWSELB" assert result[1][0] == "_ga" assert result[2][0] == "session_id" # Session cookie should have all attributes session_morsel = result[2][1] assert session_morsel["secure"] is True assert session_morsel["httponly"] is True assert session_morsel.get("samesite") == "Strict" def test_parse_set_cookie_headers_boolean_attrs() -> None: """Test that boolean attributes (secure, httponly) work correctly.""" # Test secure attribute variations headers = [ 
"cookie1=value1; Secure", "cookie2=value2; Secure=", "cookie3=value3; Secure=true", # Non-standard but might occur ] result = parse_set_cookie_headers(headers) assert len(result) == 3 # All should have secure=True for name, morsel in result: assert morsel.get("secure") is True, f"{name} should have secure=True" # Test httponly attribute variations headers = [ "cookie4=value4; HttpOnly", "cookie5=value5; HttpOnly=", ] result = parse_set_cookie_headers(headers) assert len(result) == 2 # All should have httponly=True for name, morsel in result: assert morsel.get("httponly") is True, f"{name} should have httponly=True" @pytest.mark.skipif( sys.version_info < (3, 14), reason="Partitioned cookies support requires Python 3.14+", ) def test_parse_set_cookie_headers_boolean_attrs_with_partitioned() -> None: """Test that boolean attributes including partitioned work correctly.""" # Test secure attribute variations secure_headers = [ "cookie1=value1; Secure", "cookie2=value2; Secure=", "cookie3=value3; Secure=true", # Non-standard but might occur ] result = parse_set_cookie_headers(secure_headers) assert len(result) == 3 for name, morsel in result: assert morsel.get("secure") is True, f"{name} should have secure=True" # Test httponly attribute variations httponly_headers = [ "cookie4=value4; HttpOnly", "cookie5=value5; HttpOnly=", ] result = parse_set_cookie_headers(httponly_headers) assert len(result) == 2 for name, morsel in result: assert morsel.get("httponly") is True, f"{name} should have httponly=True" # Test partitioned attribute variations partitioned_headers = [ "cookie6=value6; Partitioned", "cookie7=value7; Partitioned=", "cookie8=value8; Partitioned=yes", # Non-standard but might occur ] result = parse_set_cookie_headers(partitioned_headers) assert len(result) == 3 for name, morsel in result: assert morsel.get("partitioned") is True, f"{name} should have partitioned=True" def test_parse_set_cookie_headers_encoded_values() -> None: """Test that 
parse_set_cookie_headers preserves encoded values.""" headers = [ "encoded=hello%20world", "url=https%3A%2F%2Fexample.com%2Fpath", "special=%21%40%23%24%25%5E%26*%28%29", ] result = parse_set_cookie_headers(headers) assert len(result) == 3 # Values should be preserved as-is (not decoded) assert result[0][1].value == "hello%20world" assert result[1][1].value == "https%3A%2F%2Fexample.com%2Fpath" assert result[2][1].value == "%21%40%23%24%25%5E%26*%28%29" @pytest.mark.skipif( sys.version_info < (3, 14), reason="Partitioned cookies support requires Python 3.14+", ) def test_parse_set_cookie_headers_partitioned() -> None: """ Test that parse_set_cookie_headers handles partitioned attribute correctly. This tests the fix for issue #10380 - partitioned cookies support. The partitioned attribute is a boolean flag like secure and httponly. """ headers = [ "cookie1=value1; Partitioned", "cookie2=value2; Partitioned=", "cookie3=value3; Partitioned=true", # Non-standard but might occur "cookie4=value4; Secure; Partitioned; HttpOnly", "cookie5=value5; Domain=.example.com; Path=/; Partitioned", ] result = parse_set_cookie_headers(headers) assert len(result) == 5 # All cookies should have partitioned=True for i, (name, morsel) in enumerate(result): assert ( morsel.get("partitioned") is True ), f"Cookie {i+1} should have partitioned=True" assert name == f"cookie{i+1}" assert morsel.value == f"value{i+1}" # Cookie 4 should also have secure and httponly assert result[3][1].get("secure") is True assert result[3][1].get("httponly") is True # Cookie 5 should also have domain and path assert result[4][1].get("domain") == ".example.com" assert result[4][1].get("path") == "/" @pytest.mark.skipif( sys.version_info < (3, 14), reason="Partitioned cookies support requires Python 3.14+", ) def test_parse_set_cookie_headers_partitioned_case_insensitive() -> None: """Test that partitioned attribute is recognized case-insensitively.""" headers = [ "cookie1=value1; partitioned", # lowercase 
"cookie2=value2; PARTITIONED", # uppercase "cookie3=value3; Partitioned", # title case "cookie4=value4; PaRtItIoNeD", # mixed case ] result = parse_set_cookie_headers(headers) assert len(result) == 4 # All should be recognized as partitioned for i, (_, morsel) in enumerate(result): assert ( morsel.get("partitioned") is True ), f"Cookie {i+1} should have partitioned=True" def test_parse_set_cookie_headers_partitioned_not_set() -> None: """Test that cookies without partitioned attribute don't have it set.""" headers = [ "normal=value; Secure; HttpOnly", "regular=cookie; Path=/", ] result = parse_set_cookie_headers(headers) assert len(result) == 2 # Check that partitioned is not set (empty string is the default for flags in Morsel) assert result[0][1].get("partitioned", "") == "" assert result[1][1].get("partitioned", "") == "" # Tests that don't require partitioned support in SimpleCookie @pytest.mark.skipif( sys.version_info >= (3, 14), reason="Python 3.14+ has built-in partitioned cookie support", ) def test_parse_set_cookie_headers_partitioned_not_set_if_no_support() -> None: headers = [ "cookie1=value1; Partitioned", "cookie2=value2; Partitioned=", "cookie3=value3; Partitioned=true", ] result = parse_set_cookie_headers(headers) assert len(result) == 3 for i, (_, morsel) in enumerate(result): assert ( morsel.get("partitioned") is None ), f"Cookie {i+1} should not have partitioned flag" def test_parse_set_cookie_headers_partitioned_with_other_attrs_manual() -> None: """ Test parsing logic for partitioned cookies combined with all other attributes. This test verifies our parsing logic handles partitioned correctly as a boolean attribute regardless of SimpleCookie support. 
""" # Test that our parser recognizes partitioned in _COOKIE_KNOWN_ATTRS and _COOKIE_BOOL_ATTRS assert "partitioned" in helpers._COOKIE_KNOWN_ATTRS assert "partitioned" in helpers._COOKIE_BOOL_ATTRS # Test a simple case that won't trigger SimpleCookie errors headers = ["session=abc123; Secure; HttpOnly"] result = parse_set_cookie_headers(headers) assert len(result) == 1 assert result[0][0] == "session" assert result[0][1]["secure"] is True assert result[0][1]["httponly"] is True def test_cookie_helpers_constants_include_partitioned() -> None: """Test that cookie helper constants include partitioned attribute.""" # Test our constants include partitioned assert "partitioned" in helpers._COOKIE_KNOWN_ATTRS assert "partitioned" in helpers._COOKIE_BOOL_ATTRS @pytest.mark.parametrize( "test_string", [ " Partitioned ", " partitioned ", " PARTITIONED ", " Partitioned; ", " Partitioned= ", " Partitioned=true ", ], ) def test_cookie_pattern_matches_partitioned_attribute(test_string: str) -> None: """Test that the cookie pattern regex matches various partitioned attribute formats.""" pattern = helpers._COOKIE_PATTERN match = pattern.match(test_string) assert match is not None, f"Pattern should match '{test_string}'" assert match.group("key").lower() == "partitioned" def test_cookie_pattern_performance() -> None: """Test that the cookie pattern doesn't suffer from ReDoS issues.""" COOKIE_PATTERN_TIME_THRESHOLD_SECONDS = 0.08 value = "a" + "=" * 21651 + "\x00" start = time.perf_counter() match = helpers._COOKIE_PATTERN.match(value) elapsed = time.perf_counter() - start # If this is taking more time, there's probably a performance/ReDoS issue. assert elapsed < COOKIE_PATTERN_TIME_THRESHOLD_SECONDS, ( f"Pattern took {elapsed * 1000:.1f}ms, " f"expected <{COOKIE_PATTERN_TIME_THRESHOLD_SECONDS * 1000:.0f}ms - potential ReDoS issue" ) # This example shouldn't produce a match either. 
assert match is None def test_parse_set_cookie_headers_issue_7993_double_quotes() -> None: """ Test that cookies with unmatched opening quotes don't break parsing of subsequent cookies. This reproduces issue #7993 where a cookie containing an unmatched opening double quote causes subsequent cookies to be silently dropped. NOTE: This only fixes the specific case where a value starts with a quote but doesn't end with one (e.g., 'cookie="value'). Other malformed quote cases still behave like SimpleCookie for compatibility. """ # Test case from the issue headers = ['foo=bar; baz="qux; foo2=bar2'] result = parse_set_cookie_headers(headers) # Should parse all cookies correctly assert len(result) == 3 assert result[0][0] == "foo" assert result[0][1].value == "bar" assert result[1][0] == "baz" assert result[1][1].value == '"qux' # Unmatched quote included assert result[2][0] == "foo2" assert result[2][1].value == "bar2" def test_parse_set_cookie_headers_empty_headers() -> None: """Test handling of empty headers in the sequence.""" # Empty header should be skipped result = parse_set_cookie_headers(["", "name=value"]) assert len(result) == 1 assert result[0][0] == "name" assert result[0][1].value == "value" # Multiple empty headers result = parse_set_cookie_headers(["", "", ""]) assert result == [] # Empty headers mixed with valid cookies result = parse_set_cookie_headers(["", "a=1", "", "b=2", ""]) assert len(result) == 2 assert result[0][0] == "a" assert result[1][0] == "b" def test_parse_set_cookie_headers_invalid_cookie_syntax() -> None: """Test handling of invalid cookie syntax.""" # No valid cookie pattern result = parse_set_cookie_headers(["@#$%^&*()"]) assert result == [] # Cookie name without value result = parse_set_cookie_headers(["name"]) assert result == [] # Multiple invalid patterns result = parse_set_cookie_headers(["!!!!", "????", "name", "@@@"]) assert result == [] def test_parse_set_cookie_headers_illegal_cookie_names( caplog: pytest.LogCaptureFixture, ) 
-> None: """ Test that illegal cookie names are rejected. Note: When a known attribute name is used as a cookie name at the start, parsing stops early (before any warning can be logged). Warnings are only logged when illegal names appear after a valid cookie. """ # Cookie name that is a known attribute (illegal) - parsing stops early result = parse_set_cookie_headers(["path=value; domain=test"]) assert result == [] # Cookie name that doesn't match the pattern result = parse_set_cookie_headers(["=value"]) assert result == [] # Valid cookie after illegal one - parsing stops at illegal result = parse_set_cookie_headers(["domain=bad; good=value"]) assert result == [] # Illegal cookie name that appears after a valid cookie triggers warning result = parse_set_cookie_headers(["good=value; Path=/; invalid,cookie=value;"]) assert len(result) == 1 assert result[0][0] == "good" assert "Illegal cookie name 'invalid,cookie'" in caplog.text def test_parse_set_cookie_headers_attributes_before_cookie() -> None: """Test that attributes before any cookie are invalid.""" # Path attribute before cookie result = parse_set_cookie_headers(["Path=/; name=value"]) assert result == [] # Domain attribute before cookie result = parse_set_cookie_headers(["Domain=.example.com; name=value"]) assert result == [] # Multiple attributes before cookie result = parse_set_cookie_headers( ["Path=/; Domain=.example.com; Secure; name=value"] ) assert result == [] def test_parse_set_cookie_headers_attributes_without_values() -> None: """Test handling of attributes with missing values.""" # Boolean attribute without value (valid) result = parse_set_cookie_headers(["name=value; Secure"]) assert len(result) == 1 assert result[0][1]["secure"] is True # Non-boolean attribute without value (invalid, stops parsing) result = parse_set_cookie_headers(["name=value; Path"]) assert len(result) == 1 # Path without value stops further attribute parsing # Multiple cookies, invalid attribute in middle result = 
parse_set_cookie_headers(["name=value; Path; Secure"]) assert len(result) == 1 # Secure is not parsed because Path without value stops parsing def test_parse_set_cookie_headers_dollar_prefixed_names() -> None: """Test handling of cookie names starting with $.""" # $Version without preceding cookie (ignored) result = parse_set_cookie_headers(["$Version=1; name=value"]) assert len(result) == 1 assert result[0][0] == "name" # Multiple $ prefixed without cookie (all ignored) result = parse_set_cookie_headers(["$Version=1; $Path=/; $Domain=.com; name=value"]) assert len(result) == 1 assert result[0][0] == "name" # $ prefix at start is ignored, cookie follows result = parse_set_cookie_headers(["$Unknown=123; valid=cookie"]) assert len(result) == 1 assert result[0][0] == "valid" def test_parse_set_cookie_headers_dollar_attributes() -> None: """Test handling of $ prefixed attributes after cookies.""" # Test multiple $ attributes with cookie (case-insensitive like SimpleCookie) result = parse_set_cookie_headers(["name=value; $Path=/test; $Domain=.example.com"]) assert len(result) == 1 assert result[0][0] == "name" assert result[0][1]["path"] == "/test" assert result[0][1]["domain"] == ".example.com" # Test unknown $ attribute (should be ignored) result = parse_set_cookie_headers(["name=value; $Unknown=test"]) assert len(result) == 1 assert result[0][0] == "name" # $Unknown should not be set # Test $ attribute with empty value result = parse_set_cookie_headers(["name=value; $Path="]) assert len(result) == 1 assert result[0][1]["path"] == "" # Test case sensitivity compatibility with SimpleCookie result = parse_set_cookie_headers(["test=value; $path=/lower; $PATH=/upper"]) assert len(result) == 1 # Last one wins, and it's case-insensitive assert result[0][1]["path"] == "/upper" def test_parse_set_cookie_headers_attributes_after_illegal_cookie() -> None: """ Test that attributes after an illegal cookie name are handled correctly. 
This covers the branches where current_morsel is None because an illegal cookie name was encountered. """ # Illegal cookie followed by $ attribute result = parse_set_cookie_headers(["good=value; invalid,cookie=bad; $Path=/test"]) assert len(result) == 1 assert result[0][0] == "good" # $Path should be ignored since current_morsel is None after illegal cookie # Illegal cookie followed by boolean attribute result = parse_set_cookie_headers(["good=value; invalid,cookie=bad; HttpOnly"]) assert len(result) == 1 assert result[0][0] == "good" # HttpOnly should be ignored since current_morsel is None # Illegal cookie followed by regular attribute with value result = parse_set_cookie_headers(["good=value; invalid,cookie=bad; Max-Age=3600"]) assert len(result) == 1 assert result[0][0] == "good" # Max-Age should be ignored since current_morsel is None # Multiple attributes after illegal cookie result = parse_set_cookie_headers( ["good=value; invalid,cookie=bad; $Path=/; HttpOnly; Max-Age=60; Domain=.com"] ) assert len(result) == 1 assert result[0][0] == "good" # All attributes should be ignored after illegal cookie def test_parse_set_cookie_headers_unmatched_quotes_compatibility() -> None: """ Test that most unmatched quote scenarios behave like SimpleCookie. For compatibility, we only handle the specific case of unmatched opening quotes (e.g., 'cookie="value'). Other cases behave the same as SimpleCookie. 
""" # Cases that SimpleCookie and our parser both fail to parse completely incompatible_cases = [ 'cookie1=val"ue; cookie2=value2', # codespell:ignore 'cookie1=value"; cookie2=value2', 'cookie1=va"l"ue"; cookie2=value2', # codespell:ignore 'cookie1=value1; cookie2=val"ue; cookie3=value3', # codespell:ignore ] for header in incompatible_cases: # Test SimpleCookie behavior sc = SimpleCookie() sc.load(header) sc_cookies = list(sc.items()) # Test our parser behavior result = parse_set_cookie_headers([header]) # Both should parse the same cookies (partial parsing) assert len(result) == len(sc_cookies), ( f"Header: {header}\n" f"SimpleCookie parsed: {len(sc_cookies)} cookies\n" f"Our parser parsed: {len(result)} cookies" ) # The case we specifically fix (unmatched opening quote) fixed_case = 'cookie1=value1; cookie2="unmatched; cookie3=value3' # SimpleCookie fails to parse cookie3 sc = SimpleCookie() sc.load(fixed_case) assert len(sc) == 1 # Only cookie1 # Our parser handles it better result = parse_set_cookie_headers([fixed_case]) assert len(result) == 3 # All three cookies assert result[0][0] == "cookie1" assert result[0][1].value == "value1" assert result[1][0] == "cookie2" assert result[1][1].value == '"unmatched' assert result[2][0] == "cookie3" assert result[2][1].value == "value3" def test_parse_set_cookie_headers_expires_attribute() -> None: """Test parse_set_cookie_headers handles expires attribute with date formats.""" headers = [ "session=abc; Expires=Wed, 09 Jun 2021 10:18:14 GMT", "user=xyz; expires=Wednesday, 09-Jun-21 10:18:14 GMT", "token=123; EXPIRES=Wed, 09 Jun 2021 10:18:14 GMT", ] result = parse_set_cookie_headers(headers) assert len(result) == 3 for _, morsel in result: assert "expires" in morsel assert "GMT" in morsel["expires"] def test_parse_set_cookie_headers_edge_cases() -> None: """Test various edge cases.""" # Very long cookie values long_value = "x" * 4096 result = parse_set_cookie_headers([f"name={long_value}"]) assert len(result) == 1 
assert result[0][1].value == long_value def test_parse_set_cookie_headers_various_date_formats_issue_4327() -> None: """ Test that parse_set_cookie_headers handles various date formats per RFC 6265. This tests the fix for issue #4327 - support for RFC 822, RFC 850, and ANSI C asctime() date formats in cookie expiration. """ # Test various date formats headers = [ # RFC 822 format (preferred format) "cookie1=value1; Expires=Wed, 09 Jun 2021 10:18:14 GMT", # RFC 850 format (obsolete but still used) "cookie2=value2; Expires=Wednesday, 09-Jun-21 10:18:14 GMT", # RFC 822 with dashes "cookie3=value3; Expires=Wed, 09-Jun-2021 10:18:14 GMT", # ANSI C asctime() format (aiohttp extension - not supported by SimpleCookie) "cookie4=value4; Expires=Wed Jun 9 10:18:14 2021", # Various other formats seen in the wild "cookie5=value5; Expires=Thu, 01 Jan 2030 00:00:00 GMT", "cookie6=value6; Expires=Mon, 31-Dec-99 23:59:59 GMT", "cookie7=value7; Expires=Tue, 01-Jan-30 00:00:00 GMT", ] result = parse_set_cookie_headers(headers) # All cookies should be parsed assert len(result) == 7 # Check each cookie was parsed with its expires attribute expected_cookies = [ ("cookie1", "value1", "Wed, 09 Jun 2021 10:18:14 GMT"), ("cookie2", "value2", "Wednesday, 09-Jun-21 10:18:14 GMT"), ("cookie3", "value3", "Wed, 09-Jun-2021 10:18:14 GMT"), ("cookie4", "value4", "Wed Jun 9 10:18:14 2021"), ("cookie5", "value5", "Thu, 01 Jan 2030 00:00:00 GMT"), ("cookie6", "value6", "Mon, 31-Dec-99 23:59:59 GMT"), ("cookie7", "value7", "Tue, 01-Jan-30 00:00:00 GMT"), ] for (name, morsel), (exp_name, exp_value, exp_expires) in zip( result, expected_cookies ): assert name == exp_name assert morsel.value == exp_value assert morsel.get("expires") == exp_expires def test_parse_set_cookie_headers_ansi_c_asctime_format() -> None: """ Test parsing of ANSI C asctime() format. This tests support for ANSI C asctime() format (e.g., "Wed Jun 9 10:18:14 2021"). 
NOTE: This is an aiohttp extension - SimpleCookie does NOT support this format. """ headers = ["cookie1=value1; Expires=Wed Jun 9 10:18:14 2021"] result = parse_set_cookie_headers(headers) # Should parse correctly with the expires attribute preserved assert len(result) == 1 assert result[0][0] == "cookie1" assert result[0][1].value == "value1" assert result[0][1]["expires"] == "Wed Jun 9 10:18:14 2021" def test_parse_set_cookie_headers_rfc2822_timezone_issue_4493() -> None: """ Test that parse_set_cookie_headers handles RFC 2822 timezone formats. This tests the fix for issue #4493 - support for RFC 2822-compliant dates with timezone offsets like -0000, +0100, etc. NOTE: This is an aiohttp extension - SimpleCookie does NOT support this format. """ headers = [ # RFC 2822 with -0000 timezone (common in some APIs) "hello=world; expires=Wed, 15 Jan 2020 09:45:07 -0000", # RFC 2822 with positive offset "session=abc123; expires=Thu, 01 Feb 2024 14:30:00 +0100", # RFC 2822 with negative offset "token=xyz789; expires=Fri, 02 Mar 2025 08:15:30 -0500", # Standard GMT for comparison "classic=cookie; expires=Sat, 03 Apr 2026 12:00:00 GMT", ] result = parse_set_cookie_headers(headers) # All cookies should be parsed assert len(result) == 4 # Check each cookie was parsed with its expires attribute assert result[0][0] == "hello" assert result[0][1].value == "world" assert result[0][1]["expires"] == "Wed, 15 Jan 2020 09:45:07 -0000" assert result[1][0] == "session" assert result[1][1].value == "abc123" assert result[1][1]["expires"] == "Thu, 01 Feb 2024 14:30:00 +0100" assert result[2][0] == "token" assert result[2][1].value == "xyz789" assert result[2][1]["expires"] == "Fri, 02 Mar 2025 08:15:30 -0500" assert result[3][0] == "classic" assert result[3][1].value == "cookie" assert result[3][1]["expires"] == "Sat, 03 Apr 2026 12:00:00 GMT" def test_parse_set_cookie_headers_rfc2822_with_attributes() -> None: """Test that RFC 2822 dates work correctly with other cookie attributes.""" 
headers = [ "session=abc123; expires=Wed, 15 Jan 2020 09:45:07 -0000; Path=/; HttpOnly; Secure", "token=xyz789; expires=Thu, 01 Feb 2024 14:30:00 +0100; Domain=.example.com; SameSite=Strict", ] result = parse_set_cookie_headers(headers) assert len(result) == 2 # First cookie assert result[0][0] == "session" assert result[0][1].value == "abc123" assert result[0][1]["expires"] == "Wed, 15 Jan 2020 09:45:07 -0000" assert result[0][1]["path"] == "/" assert result[0][1]["httponly"] is True assert result[0][1]["secure"] is True # Second cookie assert result[1][0] == "token" assert result[1][1].value == "xyz789" assert result[1][1]["expires"] == "Thu, 01 Feb 2024 14:30:00 +0100" assert result[1][1]["domain"] == ".example.com" assert result[1][1]["samesite"] == "Strict" def test_parse_set_cookie_headers_date_formats_with_attributes() -> None: """Test that date formats work correctly with other cookie attributes.""" headers = [ "session=abc123; Expires=Wed, 09 Jun 2030 10:18:14 GMT; Path=/; HttpOnly; Secure", "token=xyz789; Expires=Wednesday, 09-Jun-30 10:18:14 GMT; Domain=.example.com; SameSite=Strict", ] result = parse_set_cookie_headers(headers) assert len(result) == 2 # First cookie assert result[0][0] == "session" assert result[0][1].value == "abc123" assert result[0][1]["expires"] == "Wed, 09 Jun 2030 10:18:14 GMT" assert result[0][1]["path"] == "/" assert result[0][1]["httponly"] is True assert result[0][1]["secure"] is True # Second cookie assert result[1][0] == "token" assert result[1][1].value == "xyz789" assert result[1][1]["expires"] == "Wednesday, 09-Jun-30 10:18:14 GMT" assert result[1][1]["domain"] == ".example.com" assert result[1][1]["samesite"] == "Strict" @pytest.mark.parametrize( ("header", "expected_name", "expected_value", "expected_coded"), [ # Test cookie values with octal escape sequences (r'name="\012newline\012"', "name", "\nnewline\n", r'"\012newline\012"'), ( r'tab="\011separated\011values"', "tab", "\tseparated\tvalues", 
r'"\011separated\011values"', ), ( r'mixed="hello\040world\041"', "mixed", "hello world!", r'"hello\040world\041"', ), ( r'complex="\042quoted\042 text with \012 newline"', "complex", '"quoted" text with \n newline', r'"\042quoted\042 text with \012 newline"', ), ], ) def test_parse_set_cookie_headers_uses_unquote_with_octal( header: str, expected_name: str, expected_value: str, expected_coded: str ) -> None: """Test that parse_set_cookie_headers correctly unquotes values with octal sequences and preserves coded_value.""" result = parse_set_cookie_headers([header]) assert len(result) == 1 name, morsel = result[0] # Check that octal sequences were properly decoded in the value assert name == expected_name assert morsel.value == expected_value # Check that coded_value preserves the original quoted string assert morsel.coded_value == expected_coded # Tests for parse_cookie_header (RFC 6265 compliant Cookie header parser) def test_parse_cookie_header_simple() -> None: """Test parse_cookie_header with simple cookies.""" header = "name=value; session=abc123" result = parse_cookie_header(header) assert len(result) == 2 assert result[0][0] == "name" assert result[0][1].value == "value" assert result[1][0] == "session" assert result[1][1].value == "abc123" def test_parse_cookie_header_empty() -> None: """Test parse_cookie_header with empty header.""" assert parse_cookie_header("") == [] assert parse_cookie_header(" ") == [] def test_parse_cookie_gstate_header() -> None: header = ( "_ga=ga; " "ajs_anonymous_id=0anonymous; " "analytics_session_id=session; " "cookies-analytics=true; " "cookies-functional=true; " "cookies-marketing=true; " "cookies-preferences=true; " 'g_state={"i_l":0,"i_ll":12345,"i_b":"blah"}; ' "analytics_session_id.last_access=1760128947692; " "landingPageURLRaw=landingPageURLRaw; " "landingPageURL=landingPageURL; " "referrerPageURLRaw=; " "referrerPageURL=; " "formURLRaw=formURLRaw; " "formURL=formURL; " "fbnAuthExpressCheckout=fbnAuthExpressCheckout; " 
"is_express_checkout=1; " ) result = parse_cookie_header(header) assert result[7][0] == "g_state" assert result[8][0] == "analytics_session_id.last_access" def test_parse_cookie_header_quoted_values() -> None: """Test parse_cookie_header handles quoted values correctly.""" header = 'name="quoted value"; session="with;semicolon"; data="with\\"escaped\\""' result = parse_cookie_header(header) assert len(result) == 3 assert result[0][0] == "name" assert result[0][1].value == "quoted value" assert result[1][0] == "session" assert result[1][1].value == "with;semicolon" assert result[2][0] == "data" assert result[2][1].value == 'with"escaped"' def test_parse_cookie_header_special_chars() -> None: """Test parse_cookie_header accepts special characters in names.""" header = ( "ISAWPLB{A7F52349-3531-4DA9-8776-F74BC6F4F1BB}=value1; cookie[index]=value2" ) result = parse_cookie_header(header) assert len(result) == 2 assert result[0][0] == "ISAWPLB{A7F52349-3531-4DA9-8776-F74BC6F4F1BB}" assert result[0][1].value == "value1" assert result[1][0] == "cookie[index]" assert result[1][1].value == "value2" def test_parse_cookie_header_invalid_names() -> None: """Test parse_cookie_header rejects invalid cookie names.""" # Invalid names with control characters header = "invalid\tcookie=value; valid=cookie; invalid\ncookie=bad" result = parse_cookie_header(header) # Parse_cookie_header uses same regex as parse_set_cookie_headers # Tab and newline are treated as separators, not part of names assert len(result) == 5 assert result[0][0] == "invalid" assert result[0][1].value == "" assert result[1][0] == "cookie" assert result[1][1].value == "value" assert result[2][0] == "valid" assert result[2][1].value == "cookie" assert result[3][0] == "invalid" assert result[3][1].value == "" assert result[4][0] == "cookie" assert result[4][1].value == "bad" def test_parse_cookie_header_no_attributes() -> None: """Test parse_cookie_header treats all pairs as cookies (no attributes).""" # In Cookie 
headers, even reserved attribute names are treated as cookies header = ( "session=abc123; path=/test; domain=.example.com; secure=yes; httponly=true" ) result = parse_cookie_header(header) assert len(result) == 5 assert result[0][0] == "session" assert result[0][1].value == "abc123" assert result[1][0] == "path" assert result[1][1].value == "/test" assert result[2][0] == "domain" assert result[2][1].value == ".example.com" assert result[3][0] == "secure" assert result[3][1].value == "yes" assert result[4][0] == "httponly" assert result[4][1].value == "true" def test_parse_cookie_header_empty_value() -> None: """Test parse_cookie_header with empty cookie values.""" header = "empty=; name=value; also_empty=" result = parse_cookie_header(header) assert len(result) == 3 assert result[0][0] == "empty" assert result[0][1].value == "" assert result[1][0] == "name" assert result[1][1].value == "value" assert result[2][0] == "also_empty" assert result[2][1].value == "" def test_parse_cookie_header_spaces() -> None: """Test parse_cookie_header handles spaces correctly.""" header = "name1=value1 ; name2=value2 ; name3=value3" result = parse_cookie_header(header) assert len(result) == 3 assert result[0][0] == "name1" assert result[0][1].value == "value1" assert result[1][0] == "name2" assert result[1][1].value == "value2" assert result[2][0] == "name3" assert result[2][1].value == "value3" def test_parse_cookie_header_encoded_values() -> None: """Test parse_cookie_header preserves encoded values.""" header = "encoded=hello%20world; url=https%3A%2F%2Fexample.com" result = parse_cookie_header(header) assert len(result) == 2 assert result[0][0] == "encoded" assert result[0][1].value == "hello%20world" assert result[1][0] == "url" assert result[1][1].value == "https%3A%2F%2Fexample.com" def test_parse_cookie_header_malformed() -> None: """Test parse_cookie_header handles malformed input.""" # Missing value header = "name1=value1; justname; name2=value2" result = 
parse_cookie_header(header) # Parser accepts cookies without values (empty value) assert len(result) == 3 assert result[0][0] == "name1" assert result[0][1].value == "value1" assert result[1][0] == "justname" assert result[1][1].value == "" assert result[2][0] == "name2" assert result[2][1].value == "value2" # Missing name header = "=value; name=value2" result = parse_cookie_header(header) assert len(result) == 1 assert result[0][0] == "name" assert result[0][1].value == "value2" def test_parse_cookie_header_complex_quoted() -> None: """Test parse_cookie_header with complex quoted values.""" header = 'session="abc;xyz"; data="value;with;multiple;semicolons"; simple=unquoted' result = parse_cookie_header(header) assert len(result) == 3 assert result[0][0] == "session" assert result[0][1].value == "abc;xyz" assert result[1][0] == "data" assert result[1][1].value == "value;with;multiple;semicolons" assert result[2][0] == "simple" assert result[2][1].value == "unquoted" def test_parse_cookie_header_unmatched_quotes() -> None: """Test parse_cookie_header handles unmatched quotes.""" header = 'cookie1=value1; cookie2="unmatched; cookie3=value3' result = parse_cookie_header(header) # Should parse all cookies correctly assert len(result) == 3 assert result[0][0] == "cookie1" assert result[0][1].value == "value1" assert result[1][0] == "cookie2" assert result[1][1].value == '"unmatched' assert result[2][0] == "cookie3" assert result[2][1].value == "value3" def test_parse_cookie_header_vs_parse_set_cookie_headers() -> None: """Test difference between parse_cookie_header and parse_set_cookie_headers.""" # Cookie header with attribute-like pairs cookie_header = "session=abc123; path=/test; secure=yes" # parse_cookie_header treats all as cookies cookie_result = parse_cookie_header(cookie_header) assert len(cookie_result) == 3 assert cookie_result[0][0] == "session" assert cookie_result[0][1].value == "abc123" assert cookie_result[1][0] == "path" assert cookie_result[1][1].value 
== "/test" assert cookie_result[2][0] == "secure" assert cookie_result[2][1].value == "yes" # parse_set_cookie_headers would treat path and secure as attributes set_cookie_result = parse_set_cookie_headers([cookie_header]) assert len(set_cookie_result) == 1 assert set_cookie_result[0][0] == "session" assert set_cookie_result[0][1].value == "abc123" assert set_cookie_result[0][1]["path"] == "/test" # secure with any value is treated as boolean True assert set_cookie_result[0][1]["secure"] is True def test_parse_cookie_header_compatibility_with_simple_cookie() -> None: """Test parse_cookie_header output works with SimpleCookie.""" header = "session=abc123; user=john; token=xyz789" # Parse with our function parsed = parse_cookie_header(header) # Create SimpleCookie and update with our results sc = SimpleCookie() sc.update(parsed) # Verify all cookies are present assert len(sc) == 3 assert sc["session"].value == "abc123" assert sc["user"].value == "john" assert sc["token"].value == "xyz789" def test_parse_cookie_header_real_world_examples() -> None: """Test parse_cookie_header with real-world Cookie headers.""" # Google Analytics style header = "_ga=GA1.2.1234567890.1234567890; _gid=GA1.2.0987654321.0987654321" result = parse_cookie_header(header) assert len(result) == 2 assert result[0][0] == "_ga" assert result[0][1].value == "GA1.2.1234567890.1234567890" assert result[1][0] == "_gid" assert result[1][1].value == "GA1.2.0987654321.0987654321" # Session cookies header = "PHPSESSID=abc123def456; csrf_token=xyz789; logged_in=true" result = parse_cookie_header(header) assert len(result) == 3 assert result[0][0] == "PHPSESSID" assert result[0][1].value == "abc123def456" assert result[1][0] == "csrf_token" assert result[1][1].value == "xyz789" assert result[2][0] == "logged_in" assert result[2][1].value == "true" # Complex values with proper quoting header = r'preferences="{\"theme\":\"dark\",\"lang\":\"en\"}"; session_data=eyJhbGciOiJIUzI1NiJ9' result = 
parse_cookie_header(header) assert len(result) == 2 assert result[0][0] == "preferences" assert result[0][1].value == '{"theme":"dark","lang":"en"}' assert result[1][0] == "session_data" assert result[1][1].value == "eyJhbGciOiJIUzI1NiJ9" def test_parse_cookie_header_issue_7993() -> None: """Test parse_cookie_header handles issue #7993 correctly.""" # This specific case from issue #7993 header = 'foo=bar; baz="qux; foo2=bar2' result = parse_cookie_header(header) # All cookies should be parsed assert len(result) == 3 assert result[0][0] == "foo" assert result[0][1].value == "bar" assert result[1][0] == "baz" assert result[1][1].value == '"qux' assert result[2][0] == "foo2" assert result[2][1].value == "bar2" def test_parse_cookie_header_illegal_names(caplog: pytest.LogCaptureFixture) -> None: """Test parse_cookie_header warns about illegal cookie names.""" # Cookie name with comma (not allowed in _COOKIE_NAME_RE) header = "good=value; invalid,cookie=bad; another=test" with caplog.at_level(logging.DEBUG): result = parse_cookie_header(header) # Should skip the invalid cookie but continue parsing assert len(result) == 2 assert result[0][0] == "good" assert result[0][1].value == "value" assert result[1][0] == "another" assert result[1][1].value == "test" assert "Cannot load cookie. 
Illegal cookie name" in caplog.text assert "'invalid,cookie'" in caplog.text def test_parse_cookie_header_large_value() -> None: """Test that large cookie values don't cause DoS.""" large_value = "A" * 8192 header = f"normal=value; large={large_value}; after=cookie" result = parse_cookie_header(header) cookie_names = [name for name, _ in result] assert len(result) == 3 assert "normal" in cookie_names assert "large" in cookie_names assert "after" in cookie_names large_cookie = next(morsel for name, morsel in result if name == "large") assert len(large_cookie.value) == 8192 def test_parse_cookie_header_multiple_equals() -> None: """Test handling of multiple equals signs in cookie values.""" header = "session=abc123; data=key1=val1&key2=val2; token=xyz" result = parse_cookie_header(header) assert len(result) == 3 name1, morsel1 = result[0] assert name1 == "session" assert morsel1.value == "abc123" name2, morsel2 = result[1] assert name2 == "data" assert morsel2.value == "key1=val1&key2=val2" name3, morsel3 = result[2] assert name3 == "token" assert morsel3.value == "xyz" def test_parse_cookie_header_fallback_preserves_subsequent_cookies() -> None: """Test that fallback parser doesn't lose subsequent cookies.""" header = 'normal=value; malformed={"json":"value"}; after1=cookie1; after2=cookie2' result = parse_cookie_header(header) cookie_names = [name for name, _ in result] assert len(result) == 4 assert cookie_names == ["normal", "malformed", "after1", "after2"] name1, morsel1 = result[0] assert morsel1.value == "value" name2, morsel2 = result[1] assert morsel2.value == '{"json":"value"}' name3, morsel3 = result[2] assert morsel3.value == "cookie1" name4, morsel4 = result[3] assert morsel4.value == "cookie2" def test_parse_cookie_header_whitespace_in_fallback() -> None: """Test that fallback parser handles whitespace correctly.""" header = "a=1; b = 2 ; c= 3; d =4" result = parse_cookie_header(header) assert len(result) == 4 for name, morsel in result: assert name in 
("a", "b", "c", "d") assert morsel.value in ("1", "2", "3", "4") def test_parse_cookie_header_empty_value_in_fallback() -> None: """Test that fallback handles empty values correctly.""" header = "normal=value; empty=; another=test" result = parse_cookie_header(header) assert len(result) == 3 name1, morsel1 = result[0] assert name1 == "normal" assert morsel1.value == "value" name2, morsel2 = result[1] assert name2 == "empty" assert morsel2.value == "" name3, morsel3 = result[2] assert name3 == "another" assert morsel3.value == "test" def test_parse_cookie_header_invalid_name_in_fallback( caplog: pytest.LogCaptureFixture, ) -> None: """Test that fallback parser rejects cookies with invalid names.""" header = 'normal=value; invalid,name={"x":"y"}; another=test' with caplog.at_level(logging.DEBUG): result = parse_cookie_header(header) assert len(result) == 2 name1, morsel1 = result[0] assert name1 == "normal" assert morsel1.value == "value" name2, morsel2 = result[1] assert name2 == "another" assert morsel2.value == "test" assert "Cannot load cookie. Illegal cookie name" in caplog.text assert "'invalid,name'" in caplog.text def test_parse_cookie_header_empty_key_in_fallback( caplog: pytest.LogCaptureFixture, ) -> None: """Test that fallback parser logs warning for empty cookie names.""" header = 'normal=value; ={"malformed":"json"}; another=test' with caplog.at_level(logging.DEBUG): result = parse_cookie_header(header) assert len(result) == 2 name1, morsel1 = result[0] assert name1 == "normal" assert morsel1.value == "value" name2, morsel2 = result[1] assert name2 == "another" assert morsel2.value == "test" assert "Cannot load cookie. 
Illegal cookie name" in caplog.text assert "''" in caplog.text @pytest.mark.parametrize( ("input_str", "expected"), [ # Unquoted strings should remain unchanged ("simple", "simple"), ("with spaces", "with spaces"), ("", ""), ('"', '"'), # String too short to be quoted ('some"text', 'some"text'), # Quotes not at beginning/end ('text"with"quotes', 'text"with"quotes'), ], ) def test_unquote_basic(input_str: str, expected: str) -> None: """Test basic _unquote functionality.""" assert _unquote(input_str) == expected @pytest.mark.parametrize( ("input_str", "expected"), [ # Basic quoted strings ('"quoted"', "quoted"), ('"with spaces"', "with spaces"), ('""', ""), # Empty quoted string # Quoted string with special characters ('"hello, world!"', "hello, world!"), ('"path=/test"', "path=/test"), ], ) def test_unquote_quoted_strings(input_str: str, expected: str) -> None: """Test _unquote with quoted strings.""" assert _unquote(input_str) == expected @pytest.mark.parametrize( ("input_str", "expected"), [ # Escaped quotes should be unescaped (r'"say \"hello\""', 'say "hello"'), (r'"nested \"quotes\" here"', 'nested "quotes" here'), # Multiple escaped quotes (r'"\"start\" middle \"end\""', '"start" middle "end"'), ], ) def test_unquote_escaped_quotes(input_str: str, expected: str) -> None: """Test _unquote with escaped quotes.""" assert _unquote(input_str) == expected @pytest.mark.parametrize( ("input_str", "expected"), [ # Single escaped backslash (r'"path\\to\\file"', "path\\to\\file"), # Backslash before quote (r'"end with slash\\"', "end with slash\\"), # Mixed escaped characters (r'"path\\to\\\"file\""', 'path\\to\\"file"'), ], ) def test_unquote_escaped_backslashes(input_str: str, expected: str) -> None: """Test _unquote with escaped backslashes.""" assert _unquote(input_str) == expected @pytest.mark.parametrize( ("input_str", "expected"), [ # Common octal sequences (r'"\012"', "\n"), # newline (r'"\011"', "\t"), # tab (r'"\015"', "\r"), # carriage return (r'"\040"', " 
"), # space # Octal sequences in context (r'"line1\012line2"', "line1\nline2"), (r'"tab\011separated"', "tab\tseparated"), # Multiple octal sequences (r'"\012\011\015"', "\n\t\r"), # Mixed octal and regular text (r'"hello\040world\041"', "hello world!"), ], ) def test_unquote_octal_sequences(input_str: str, expected: str) -> None: """Test _unquote with octal escape sequences.""" assert _unquote(input_str) == expected @pytest.mark.parametrize( ("input_str", "expected"), [ # Test boundary values (r'"\000"', "\x00"), # null character (r'"\001"', "\x01"), (r'"\177"', "\x7f"), # DEL character (r'"\200"', "\x80"), # Extended ASCII (r'"\377"', "\xff"), # Max octal value # Invalid octal sequences (not 3 digits or > 377) are treated as regular escapes (r'"\400"', "400"), # 400 octal = 256 decimal, too large (r'"\777"', "777"), # 777 octal = 511 decimal, too large ], ) def test_unquote_octal_full_range(input_str: str, expected: str) -> None: """Test _unquote with full range of valid octal sequences.""" assert _unquote(input_str) == expected @pytest.mark.parametrize( ("input_str", "expected"), [ # Mix of quotes, backslashes, and octal (r'"say \"hello\"\012new line"', 'say "hello"\nnew line'), (r'"path\\to\\file\011\011data"', "path\\to\\file\t\tdata"), # Complex mixed example (r'"\042quoted\042 and \134backslash\134"', '"quoted" and \\backslash\\'), # Escaped characters that aren't special (r'"\a\b\c"', "abc"), # \a, \b, \c -> a, b, c ], ) def test_unquote_mixed_escapes(input_str: str, expected: str) -> None: """Test _unquote with mixed escape sequences.""" assert _unquote(input_str) == expected @pytest.mark.parametrize( ("input_str", "expected"), [ # String that starts with quote but doesn't end with one ('"not closed', '"not closed'), # String that ends with quote but doesn't start with one ('not opened"', 'not opened"'), # Multiple quotes ('"""', '"'), ('""""', '""'), # Backslash at the end without anything to escape (r'"ends with\"', "ends with\\"), # Empty escape 
(r'"test\"', "test\\"), # Just escaped characters (r'"\"\"\""', '"""'), ], ) def test_unquote_edge_cases(input_str: str, expected: str) -> None: """Test _unquote edge cases.""" assert _unquote(input_str) == expected @pytest.mark.parametrize( ("input_str", "expected"), [ # JSON-like data (r'"{\"user\":\"john\",\"id\":123}"', '{"user":"john","id":123}'), # URL-encoded then quoted ('"hello%20world"', "hello%20world"), # Path with backslashes (Windows-style) (r'"C:\\Users\\John\\Documents"', "C:\\Users\\John\\Documents"), # Complex session data ( r'"session_data=\"user123\";expires=2024"', 'session_data="user123";expires=2024', ), ], ) def test_unquote_real_world_examples(input_str: str, expected: str) -> None: """Test _unquote with real-world cookie value examples.""" assert _unquote(input_str) == expected @pytest.mark.parametrize( "test_value", [ '""', '"simple"', r'"with \"quotes\""', r'"with \\backslash\\"', r'"\012newline"', r'"complex\042quote\134slash\012"', '"not-quoted', 'also-not-quoted"', r'"mixed\011\042\134test"', ], ) def test_unquote_compatibility_with_simplecookie(test_value: str) -> None: """Test that _unquote behaves like SimpleCookie's unquoting.""" assert _unquote(test_value) == simplecookie_unquote(test_value), ( f"Mismatch for {test_value!r}: " f"our={_unquote(test_value)!r}, " f"SimpleCookie={simplecookie_unquote(test_value)!r}" ) ================================================ FILE: tests/test_cookiejar.py ================================================ import datetime import heapq import itertools import logging import pickle import sys from http.cookies import BaseCookie, Morsel, SimpleCookie from operator import not_ from pathlib import Path from unittest import mock import pytest from freezegun import freeze_time from yarl import URL from aiohttp import CookieJar, DummyCookieJar from aiohttp.typedefs import LooseCookies def dump_cookiejar() -> bytes: # pragma: no cover """Create pickled data for test_pickle_format().""" cj = CookieJar() 
cj.update_cookies(_cookies_to_send()) return pickle.dumps(cj._cookies, pickle.HIGHEST_PROTOCOL) def _cookies_to_send() -> SimpleCookie: return SimpleCookie( "shared-cookie=first; " "domain-cookie=second; Domain=example.com; " "subdomain1-cookie=third; Domain=test1.example.com; " "subdomain2-cookie=fourth; Domain=test2.example.com; " "dotted-domain-cookie=fifth; Domain=.example.com; " "different-domain-cookie=sixth; Domain=different.org; " "secure-cookie=seventh; Domain=secure.com; Secure; " "no-path-cookie=eighth; Domain=pathtest.com; " "path1-cookie=ninth; Domain=pathtest.com; Path=/; " "path2-cookie=tenth; Domain=pathtest.com; Path=/one; " "path3-cookie=eleventh; Domain=pathtest.com; Path=/one/two; " "path4-cookie=twelfth; Domain=pathtest.com; Path=/one/two/; " "expires-cookie=thirteenth; Domain=expirestest.com; Path=/;" " Expires=Tue, 1 Jan 2999 12:00:00 GMT; " "max-age-cookie=fourteenth; Domain=maxagetest.com; Path=/;" " Max-Age=60; " "invalid-max-age-cookie=fifteenth; Domain=invalid-values.com; " " Max-Age=string; " "invalid-expires-cookie=sixteenth; Domain=invalid-values.com; " " Expires=string;" ) @pytest.fixture def cookies_to_send() -> SimpleCookie: return _cookies_to_send() @pytest.fixture def cookies_to_send_with_expired() -> SimpleCookie: return SimpleCookie( "shared-cookie=first; " "domain-cookie=second; Domain=example.com; " "subdomain1-cookie=third; Domain=test1.example.com; " "subdomain2-cookie=fourth; Domain=test2.example.com; " "dotted-domain-cookie=fifth; Domain=.example.com; " "different-domain-cookie=sixth; Domain=different.org; " "secure-cookie=seventh; Domain=secure.com; Secure; " "no-path-cookie=eighth; Domain=pathtest.com; " "path1-cookie=ninth; Domain=pathtest.com; Path=/; " "path2-cookie=tenth; Domain=pathtest.com; Path=/one; " "path3-cookie=eleventh; Domain=pathtest.com; Path=/one/two; " "path4-cookie=twelfth; Domain=pathtest.com; Path=/one/two/; " "expires-cookie=thirteenth; Domain=expirestest.com; Path=/;" " Expires=Tue, 1 Jan 1980 
12:00:00 GMT; " "max-age-cookie=fourteenth; Domain=maxagetest.com; Path=/;" " Max-Age=60; " "invalid-max-age-cookie=fifteenth; Domain=invalid-values.com; " " Max-Age=string; " "invalid-expires-cookie=sixteenth; Domain=invalid-values.com; " " Expires=string;" ) @pytest.fixture def cookies_to_receive() -> SimpleCookie: return SimpleCookie( "unconstrained-cookie=first; Path=/; " "domain-cookie=second; Domain=example.com; Path=/; " "subdomain1-cookie=third; Domain=test1.example.com; Path=/; " "subdomain2-cookie=fourth; Domain=test2.example.com; Path=/; " "dotted-domain-cookie=fifth; Domain=.example.com; Path=/; " "different-domain-cookie=sixth; Domain=different.org; Path=/; " "no-path-cookie=seventh; Domain=pathtest.com; " "path-cookie=eighth; Domain=pathtest.com; Path=/somepath; " "wrong-path-cookie=ninth; Domain=pathtest.com; Path=somepath;" ) def test_date_parsing() -> None: parse_func = CookieJar._parse_date utc = datetime.timezone.utc assert parse_func("") is None # 70 -> 1970 assert ( parse_func("Tue, 1 Jan 70 00:00:00 GMT") == datetime.datetime(1970, 1, 1, tzinfo=utc).timestamp() ) # 10 -> 2010 assert ( parse_func("Tue, 1 Jan 10 00:00:00 GMT") == datetime.datetime(2010, 1, 1, tzinfo=utc).timestamp() ) # No day of week string assert ( parse_func("1 Jan 1970 00:00:00 GMT") == datetime.datetime(1970, 1, 1, tzinfo=utc).timestamp() ) # No timezone string assert ( parse_func("Tue, 1 Jan 1970 00:00:00") == datetime.datetime(1970, 1, 1, tzinfo=utc).timestamp() ) # No year assert parse_func("Tue, 1 Jan 00:00:00 GMT") is None # No month assert parse_func("Tue, 1 1970 00:00:00 GMT") is None # No day of month assert parse_func("Tue, Jan 1970 00:00:00 GMT") is None # No time assert parse_func("Tue, 1 Jan 1970 GMT") is None # Invalid day of month assert parse_func("Tue, 0 Jan 1970 00:00:00 GMT") is None # Invalid year assert parse_func("Tue, 1 Jan 1500 00:00:00 GMT") is None # Invalid time assert parse_func("Tue, 1 Jan 1970 77:88:99 GMT") is None def test_domain_matching() -> 
None: test_func = CookieJar._is_domain_match assert test_func("test.com", "test.com") assert test_func("test.com", "sub.test.com") assert not test_func("test.com", "") assert not test_func("test.com", "test.org") assert not test_func("diff-test.com", "test.com") assert not test_func("test.com", "diff-test.com") assert not test_func("test.com", "127.0.0.1") async def test_constructor( cookies_to_send: SimpleCookie, cookies_to_receive: SimpleCookie ) -> None: jar = CookieJar() jar.update_cookies(cookies_to_send) jar_cookies = SimpleCookie() for cookie in jar: dict.__setitem__(jar_cookies, cookie.key, cookie) expected_cookies = cookies_to_send assert jar_cookies == expected_cookies async def test_constructor_with_expired( cookies_to_send_with_expired: SimpleCookie, cookies_to_receive: SimpleCookie ) -> None: jar = CookieJar() jar.update_cookies(cookies_to_send_with_expired) jar_cookies = SimpleCookie() for cookie in jar: dict.__setitem__(jar_cookies, cookie.key, cookie) expected_cookies = cookies_to_send_with_expired assert jar_cookies != expected_cookies def test_save_load( tmp_path: Path, cookies_to_send: SimpleCookie, cookies_to_receive: SimpleCookie, ) -> None: file_path = Path(str(tmp_path)) / "aiohttp.test.cookie" # export cookie jar jar_save = CookieJar() jar_save.update_cookies(cookies_to_receive) jar_save.save(file_path=file_path) jar_load = CookieJar() jar_load.load(file_path=file_path) jar_test = SimpleCookie() for cookie in jar_load: jar_test[cookie.key] = cookie assert jar_test == cookies_to_receive def test_save_load_partitioned_cookies(tmp_path: Path) -> None: file_path = Path(str(tmp_path)) / "aiohttp.test2.cookie" # export cookie jar jar_save = CookieJar() jar_save.update_cookies_from_headers( ["session=cookie; Partitioned"], URL("https://example.com/") ) jar_save.save(file_path=file_path) jar_load = CookieJar() jar_load.load(file_path=file_path) assert jar_save._cookies == jar_load._cookies async def test_update_cookie_with_unicode_domain() -> None: 
cookies = ( "idna-domain-first=first; Domain=xn--9caa.com; Path=/;", "idna-domain-second=second; Domain=xn--9caa.com; Path=/;", ) jar = CookieJar() jar.update_cookies(SimpleCookie(cookies[0]), URL("http://éé.com/")) jar.update_cookies(SimpleCookie(cookies[1]), URL("http://xn--9caa.com/")) jar_test = SimpleCookie() for cookie in jar: jar_test[cookie.key] = cookie assert jar_test == SimpleCookie(" ".join(cookies)) async def test_filter_cookie_with_unicode_domain() -> None: jar = CookieJar() jar.update_cookies( SimpleCookie("idna-domain-first=first; Domain=xn--9caa.com; Path=/; ") ) assert len(jar.filter_cookies(URL("http://éé.com"))) == 1 assert len(jar.filter_cookies(URL("http://xn--9caa.com"))) == 1 async def test_filter_cookies_str_deprecated() -> None: jar = CookieJar() with pytest.deprecated_call( match="The method accepts yarl.URL instances only, got ", ): jar.filter_cookies("http://éé.com") # type: ignore[arg-type] @pytest.mark.parametrize( ("url", "expected_cookies"), ( ( "http://pathtest.com/one/two/", ( "no-path-cookie", "path1-cookie", "path2-cookie", "shared-cookie", "path3-cookie", "path4-cookie", ), ), ( "http://pathtest.com/one/two", ( "no-path-cookie", "path1-cookie", "path2-cookie", "shared-cookie", "path3-cookie", ), ), ( "http://pathtest.com/one/two/three/", ( "no-path-cookie", "path1-cookie", "path2-cookie", "shared-cookie", "path3-cookie", "path4-cookie", ), ), ( "http://test1.example.com/", ( "shared-cookie", "domain-cookie", "subdomain1-cookie", "dotted-domain-cookie", ), ), ( "http://pathtest.com/", ( "shared-cookie", "no-path-cookie", "path1-cookie", ), ), ), ) async def test_filter_cookies_with_domain_path_lookup_multilevelpath( url: str, expected_cookies: set[str], ) -> None: jar = CookieJar() cookie = SimpleCookie( "shared-cookie=first; " "domain-cookie=second; Domain=example.com; " "subdomain1-cookie=third; Domain=test1.example.com; " "subdomain2-cookie=fourth; Domain=test2.example.com; " "dotted-domain-cookie=fifth; Domain=.example.com; 
" "different-domain-cookie=sixth; Domain=different.org; " "secure-cookie=seventh; Domain=secure.com; Secure; " "no-path-cookie=eighth; Domain=pathtest.com; " "path1-cookie=ninth; Domain=pathtest.com; Path=/; " "path2-cookie=tenth; Domain=pathtest.com; Path=/one; " "path3-cookie=eleventh; Domain=pathtest.com; Path=/one/two; " "path4-cookie=twelfth; Domain=pathtest.com; Path=/one/two/; " "expires-cookie=thirteenth; Domain=expirestest.com; Path=/;" " Expires=Tue, 1 Jan 1980 12:00:00 GMT; " "max-age-cookie=fourteenth; Domain=maxagetest.com; Path=/;" " Max-Age=60; " "invalid-max-age-cookie=fifteenth; Domain=invalid-values.com; " " Max-Age=string; " "invalid-expires-cookie=sixteenth; Domain=invalid-values.com; " " Expires=string;" ) jar.update_cookies(cookie) cookies = jar.filter_cookies(URL(url)) assert len(cookies) == len(expected_cookies) for c in cookies: assert c in expected_cookies async def test_domain_filter_ip_cookie_send() -> None: jar = CookieJar() cookies = SimpleCookie( "shared-cookie=first; " "domain-cookie=second; Domain=example.com; " "subdomain1-cookie=third; Domain=test1.example.com; " "subdomain2-cookie=fourth; Domain=test2.example.com; " "dotted-domain-cookie=fifth; Domain=.example.com; " "different-domain-cookie=sixth; Domain=different.org; " "secure-cookie=seventh; Domain=secure.com; Secure; " "no-path-cookie=eighth; Domain=pathtest.com; " "path1-cookie=ninth; Domain=pathtest.com; Path=/; " "path2-cookie=tenth; Domain=pathtest.com; Path=/one; " "path3-cookie=eleventh; Domain=pathtest.com; Path=/one/two; " "path4-cookie=twelfth; Domain=pathtest.com; Path=/one/two/; " "expires-cookie=thirteenth; Domain=expirestest.com; Path=/;" " Expires=Tue, 1 Jan 1980 12:00:00 GMT; " "max-age-cookie=fourteenth; Domain=maxagetest.com; Path=/;" " Max-Age=60; " "invalid-max-age-cookie=fifteenth; Domain=invalid-values.com; " " Max-Age=string; " "invalid-expires-cookie=sixteenth; Domain=invalid-values.com; " " Expires=string;" ) jar.update_cookies(cookies) cookies_sent = 
jar.filter_cookies(URL("http://1.2.3.4/")).output(header="Cookie:") assert cookies_sent == "Cookie: shared-cookie=first" async def test_domain_filter_ip_cookie_receive( cookies_to_receive: SimpleCookie, ) -> None: jar = CookieJar() jar.update_cookies(cookies_to_receive, URL("http://1.2.3.4/")) assert len(jar) == 0 @pytest.mark.parametrize( ("cookies", "expected", "quote_bool"), [ ( "shared-cookie=first; ip-cookie=second; Domain=127.0.0.1;", "Cookie: ip-cookie=second\r\nCookie: shared-cookie=first", True, ), ('ip-cookie="second"; Domain=127.0.0.1;', 'Cookie: ip-cookie="second"', True), ("custom-cookie=value/one;", 'Cookie: custom-cookie="value/one"', True), ("custom-cookie=value1;", "Cookie: custom-cookie=value1", True), ("custom-cookie=value/one;", "Cookie: custom-cookie=value/one", False), ('foo="quoted_value"', 'Cookie: foo="quoted_value"', True), ('foo="quoted_value"; domain=127.0.0.1', 'Cookie: foo="quoted_value"', True), ], ids=( "IP domain preserved", "no shared cookie", "quoted cookie with special char", "quoted cookie w/o special char", "unquoted cookie with special char", "pre-quoted cookie", "pre-quoted cookie with domain", ), ) async def test_quotes_correctly_based_on_input( cookies: str, expected: str, quote_bool: bool ) -> None: jar = CookieJar(unsafe=True, quote_cookie=quote_bool) jar.update_cookies(SimpleCookie(cookies)) cookies_sent = jar.filter_cookies(URL("http://127.0.0.1/")).output(header="Cookie:") assert cookies_sent == expected async def test_ignore_domain_ending_with_dot() -> None: jar = CookieJar(unsafe=True) jar.update_cookies( SimpleCookie("cookie=val; Domain=example.com.;"), URL("http://www.example.com") ) cookies_sent = jar.filter_cookies(URL("http://www.example.com/")) assert cookies_sent.output(header="Cookie:") == "Cookie: cookie=val" cookies_sent = jar.filter_cookies(URL("http://example.com/")) assert cookies_sent.output(header="Cookie:") == "" class TestCookieJarSafe: @pytest.fixture(autouse=True) def setup_cookies( self, 
cookies_to_send_with_expired: SimpleCookie, cookies_to_receive: SimpleCookie, ) -> None: self.cookies_to_send = cookies_to_send_with_expired self.cookies_to_receive = cookies_to_receive def request_reply_with_same_url( self, url: str ) -> tuple["BaseCookie[str]", SimpleCookie]: jar = CookieJar() jar.update_cookies(self.cookies_to_send) cookies_sent = jar.filter_cookies(URL(url)) jar.clear() jar.update_cookies(self.cookies_to_receive, URL(url)) cookies_received = SimpleCookie() for cookie in jar: dict.__setitem__(cookies_received, cookie.key, cookie) jar.clear() return cookies_sent, cookies_received def timed_request( self, url: str, update_time: float, send_time: float ) -> "BaseCookie[str]": jar = CookieJar() freeze_update_time: datetime.datetime | datetime.timedelta freeze_send_time: datetime.datetime | datetime.timedelta if isinstance(update_time, int): freeze_update_time = datetime.timedelta(seconds=update_time) else: freeze_update_time = datetime.datetime.fromtimestamp(update_time) if isinstance(send_time, int): freeze_send_time = datetime.timedelta(seconds=send_time) else: freeze_send_time = datetime.datetime.fromtimestamp(send_time) with freeze_time(freeze_update_time): jar.update_cookies(self.cookies_to_send) with freeze_time(freeze_send_time): cookies_sent = jar.filter_cookies(URL(url)) jar.clear() return cookies_sent def test_domain_filter_same_host(self) -> None: cookies_sent, cookies_received = self.request_reply_with_same_url( "http://example.com/" ) assert set(cookies_sent.keys()) == { "shared-cookie", "domain-cookie", "dotted-domain-cookie", } assert set(cookies_received.keys()) == { "unconstrained-cookie", "domain-cookie", "dotted-domain-cookie", } def test_domain_filter_same_host_and_subdomain(self) -> None: cookies_sent, cookies_received = self.request_reply_with_same_url( "http://test1.example.com/" ) assert set(cookies_sent.keys()) == { "shared-cookie", "domain-cookie", "subdomain1-cookie", "dotted-domain-cookie", } assert 
set(cookies_received.keys()) == { "unconstrained-cookie", "domain-cookie", "subdomain1-cookie", "dotted-domain-cookie", } def test_domain_filter_same_host_diff_subdomain(self) -> None: cookies_sent, cookies_received = self.request_reply_with_same_url( "http://different.example.com/" ) assert set(cookies_sent.keys()) == { "shared-cookie", "domain-cookie", "dotted-domain-cookie", } assert set(cookies_received.keys()) == { "unconstrained-cookie", "domain-cookie", "dotted-domain-cookie", } def test_domain_filter_diff_host(self) -> None: cookies_sent, cookies_received = self.request_reply_with_same_url( "http://different.org/" ) assert set(cookies_sent.keys()) == {"shared-cookie", "different-domain-cookie"} assert set(cookies_received.keys()) == { "unconstrained-cookie", "different-domain-cookie", } def test_domain_filter_host_only(self, cookies_to_receive: SimpleCookie) -> None: jar = CookieJar() jar.update_cookies(cookies_to_receive, URL("http://example.com/")) sub_cookie = SimpleCookie("subdomain=spam; Path=/;") jar.update_cookies(sub_cookie, URL("http://foo.example.com/")) cookies_sent = jar.filter_cookies(URL("http://foo.example.com/")) assert "subdomain" in set(cookies_sent.keys()) assert "unconstrained-cookie" not in set(cookies_sent.keys()) def test_secure_filter(self) -> None: cookies_sent, _ = self.request_reply_with_same_url("http://secure.com/") assert set(cookies_sent.keys()) == {"shared-cookie"} cookies_sent, _ = self.request_reply_with_same_url("https://secure.com/") assert set(cookies_sent.keys()) == {"shared-cookie", "secure-cookie"} def test_path_filter_root(self) -> None: cookies_sent, _ = self.request_reply_with_same_url("http://pathtest.com/") assert set(cookies_sent.keys()) == { "shared-cookie", "no-path-cookie", "path1-cookie", } def test_path_filter_folder(self) -> None: cookies_sent, _ = self.request_reply_with_same_url("http://pathtest.com/one/") assert set(cookies_sent.keys()) == { "shared-cookie", "no-path-cookie", "path1-cookie", 
"path2-cookie", } def test_path_filter_file(self) -> None: cookies_sent, _ = self.request_reply_with_same_url( "http://pathtest.com/one/two" ) assert set(cookies_sent.keys()) == { "shared-cookie", "no-path-cookie", "path1-cookie", "path2-cookie", "path3-cookie", } def test_path_filter_subfolder(self) -> None: cookies_sent, _ = self.request_reply_with_same_url( "http://pathtest.com/one/two/" ) assert set(cookies_sent.keys()) == { "shared-cookie", "no-path-cookie", "path1-cookie", "path2-cookie", "path3-cookie", "path4-cookie", } def test_path_filter_subsubfolder(self) -> None: cookies_sent, _ = self.request_reply_with_same_url( "http://pathtest.com/one/two/three/" ) assert set(cookies_sent.keys()) == { "shared-cookie", "no-path-cookie", "path1-cookie", "path2-cookie", "path3-cookie", "path4-cookie", } def test_path_filter_different_folder(self) -> None: cookies_sent, _ = self.request_reply_with_same_url( "http://pathtest.com/hundred/" ) assert set(cookies_sent.keys()) == { "shared-cookie", "no-path-cookie", "path1-cookie", } def test_path_value(self) -> None: _, cookies_received = self.request_reply_with_same_url("http://pathtest.com/") assert set(cookies_received.keys()) == { "unconstrained-cookie", "no-path-cookie", "path-cookie", "wrong-path-cookie", } assert cookies_received["no-path-cookie"]["path"] == "/" assert cookies_received["path-cookie"]["path"] == "/somepath" assert cookies_received["wrong-path-cookie"]["path"] == "/" def test_expires(self) -> None: ts_before = datetime.datetime( 1975, 1, 1, tzinfo=datetime.timezone.utc ).timestamp() ts_after = datetime.datetime( 2030, 1, 1, tzinfo=datetime.timezone.utc ).timestamp() cookies_sent = self.timed_request( "http://expirestest.com/", ts_before, ts_before ) assert set(cookies_sent.keys()) == {"shared-cookie", "expires-cookie"} cookies_sent = self.timed_request( "http://expirestest.com/", ts_before, ts_after ) assert set(cookies_sent.keys()) == {"shared-cookie"} def test_max_age(self) -> None: cookies_sent = 
self.timed_request("http://maxagetest.com/", 1000, 1000) assert set(cookies_sent.keys()) == {"shared-cookie", "max-age-cookie"} cookies_sent = self.timed_request("http://maxagetest.com/", 1000, 2000) assert set(cookies_sent.keys()) == {"shared-cookie"} def test_invalid_values(self) -> None: cookies_sent, cookies_received = self.request_reply_with_same_url( "http://invalid-values.com/" ) assert set(cookies_sent.keys()) == { "shared-cookie", "invalid-max-age-cookie", "invalid-expires-cookie", } cookie = cookies_sent["invalid-max-age-cookie"] assert cookie["max-age"] == "" cookie = cookies_sent["invalid-expires-cookie"] assert cookie["expires"] == "" async def test_cookie_not_expired_when_added_after_removal(self) -> None: # Test case for https://github.com/aio-libs/aiohttp/issues/2084 timestamps = [ 533588.993, 533588.993, 533588.993, 533588.993, 533589.093, 533589.093, ] loop = mock.Mock() loop.time.side_effect = itertools.chain( timestamps, itertools.cycle([timestamps[-1]]) ) jar = CookieJar(unsafe=True) # Remove `foo` cookie. jar.update_cookies(SimpleCookie('foo=""; Max-Age=0')) # Set `foo` cookie to `bar`. jar.update_cookies(SimpleCookie('foo="bar"')) # Assert that there is a cookie. 
assert len(jar) == 1

    async def test_path_filter_diff_folder_same_name(self) -> None:
        """Check that same-named cookies on different paths are kept apart.

        Two ``path-cookie`` entries (``/`` and ``/one``) are stored; each
        request URL must get exactly one of them, matching its own path.
        """
        jar = CookieJar(unsafe=True)

        jar.update_cookies(
            SimpleCookie("path-cookie=zero; Domain=pathtest.com; Path=/; ")
        )
        jar.update_cookies(
            SimpleCookie("path-cookie=one; Domain=pathtest.com; Path=/one; ")
        )
        # Same name, different path: both entries are stored.
        assert len(jar) == 2

        jar_filtered = jar.filter_cookies(URL("http://pathtest.com/"))
        assert len(jar_filtered) == 1
        assert jar_filtered["path-cookie"].value == "zero"

        jar_filtered = jar.filter_cookies(URL("http://pathtest.com/one"))
        assert len(jar_filtered) == 1
        assert jar_filtered["path-cookie"].value == "one"

    async def test_path_filter_diff_folder_same_name_return_best_match_independent_from_put_order(
        self,
    ) -> None:
        """Check that the path match does not depend on insertion order.

        The ``/one`` cookie is stored before the ``/`` cookie, and a third
        one is stored for ``/second``; each URL must still get the cookie
        registered for its own path.
        """
        jar = CookieJar(unsafe=True)
        jar.update_cookies(
            SimpleCookie("path-cookie=one; Domain=pathtest.com; Path=/one; ")
        )
        jar.update_cookies(
            SimpleCookie("path-cookie=zero; Domain=pathtest.com; Path=/; ")
        )
        jar.update_cookies(
            SimpleCookie("path-cookie=two; Domain=pathtest.com; Path=/second; ")
        )
        assert len(jar) == 3

        jar_filtered = jar.filter_cookies(URL("http://pathtest.com/"))
        assert len(jar_filtered) == 1
        assert jar_filtered["path-cookie"].value == "zero"

        jar_filtered = jar.filter_cookies(URL("http://pathtest.com/second"))
        assert len(jar_filtered) == 1
        assert jar_filtered["path-cookie"].value == "two"

        jar_filtered = jar.filter_cookies(URL("http://pathtest.com/one"))
        assert len(jar_filtered) == 1
        assert jar_filtered["path-cookie"].value == "one"


async def test_dummy_cookie_jar() -> None:
    """Check that DummyCookieJar never stores anything.

    It must stay empty after ``update_cookies``, yield nothing when
    iterated, return a falsy result from ``filter_cookies`` and accept
    ``clear()``.
    """
    cookie = SimpleCookie("foo=bar; Domain=example.com;")
    dummy_jar = DummyCookieJar()
    assert dummy_jar.quote_cookie is True
    assert len(dummy_jar) == 0
    dummy_jar.update_cookies(cookie)
    assert len(dummy_jar) == 0
    with pytest.raises(StopIteration):
        next(iter(dummy_jar))
    assert not dummy_jar.filter_cookies(URL("http://example.com/"))
    dummy_jar.clear()


async def test_loose_cookies_types() -> None:
    """Check that update_cookies accepts every listed LooseCookies shape."""
    jar = CookieJar()

    accepted_types: tuple[LooseCookies, ...]
= (
        [("str", BaseCookie())],
        [("str", Morsel())],
        [("str", "str")],
        {"str": BaseCookie()},
        {"str": Morsel()},
        {"str": "str"},
        SimpleCookie(),
    )

    # The test passes as long as none of these calls raises.
    for loose_cookies_type in accepted_types:
        jar.update_cookies(cookies=loose_cookies_type)


async def test_cookie_jar_clear_all() -> None:
    """Check that ``clear()`` without arguments empties the jar."""
    sut = CookieJar()
    cookie = SimpleCookie()
    cookie["foo"] = "bar"
    sut.update_cookies(cookie)

    sut.clear()
    assert len(sut) == 0


async def test_cookie_jar_clear_expired() -> None:
    """Check that a cookie past its expiry is gone after ``clear(not_)``.

    The cookie (expiring in 1990) is stored with the clock frozen at 1980,
    ``clear(not_)`` runs outside the freeze, and the jar must then be empty
    even when inspected with the clock frozen at 1980 again.
    """
    sut = CookieJar()

    cookie = SimpleCookie()

    cookie["foo"] = "bar"
    cookie["foo"]["expires"] = "Tue, 1 Jan 1990 12:00:00 GMT"

    with freeze_time("1980-01-01"):
        sut.update_cookies(cookie)

    # NOTE(review): `not_` is presumably `operator.not_` - confirm against
    # the imports at the top of the file.
    for _ in range(2):
        sut.clear(not_)
        with freeze_time("1980-01-01"):
            assert len(sut) == 0


async def test_cookie_jar_expired_changes() -> None:
    """Test that expire time changes are handled as expected."""
    jar = CookieJar()

    # Four versions of the same cookie that differ only in expiry time.
    cookie_eleven_am = SimpleCookie()
    cookie_eleven_am["foo"] = "bar"
    cookie_eleven_am["foo"]["expires"] = "Tue, 1 Jan 1990 11:00:00 GMT"

    cookie_noon = SimpleCookie()
    cookie_noon["foo"] = "bar"
    cookie_noon["foo"]["expires"] = "Tue, 1 Jan 1990 12:00:00 GMT"

    cookie_one_pm = SimpleCookie()
    cookie_one_pm["foo"] = "bar"
    cookie_one_pm["foo"]["expires"] = "Tue, 1 Jan 1990 13:00:00 GMT"

    cookie_two_pm = SimpleCookie()
    cookie_two_pm["foo"] = "bar"
    cookie_two_pm["foo"]["expires"] = "Tue, 1 Jan 1990 14:00:00 GMT"

    with freeze_time() as freezer:
        freezer.move_to("1990-01-01 10:00:00+00:00")
        jar.update_cookies(cookie_noon)
        assert len(jar) == 1
        matched_cookies = jar.filter_cookies(URL("/"))
        assert len(matched_cookies) == 1
        assert "foo" in matched_cookies

        # Move the expiry earlier, then later twice; the cookie must stay
        # visible after every update while the clock is still at 10:00.
        jar.update_cookies(cookie_eleven_am)
        matched_cookies = jar.filter_cookies(URL("/"))
        assert len(matched_cookies) == 1
        assert "foo" in matched_cookies

        jar.update_cookies(cookie_one_pm)
        matched_cookies = jar.filter_cookies(URL("/"))
        assert len(matched_cookies) == 1
        assert "foo" in matched_cookies

        jar.update_cookies(cookie_two_pm)
        matched_cookies = jar.filter_cookies(URL("/"))
        assert
len(matched_cookies) == 1
        assert "foo" in matched_cookies

        # The cookie is expected to survive 13:00 and to be gone at 14:00.
        freezer.move_to("1990-01-01 13:00:00+00:00")
        matched_cookies = jar.filter_cookies(URL("/"))
        assert len(matched_cookies) == 1
        assert "foo" in matched_cookies

        freezer.move_to("1990-01-01 14:00:00+00:00")
        matched_cookies = jar.filter_cookies(URL("/"))
        assert len(matched_cookies) == 0


async def test_cookie_jar_duplicates_with_expire_heap() -> None:
    """Test that duplicate cookies do not grow the expires heap."""
    jar = CookieJar()

    cookie_eleven_am = SimpleCookie()
    cookie_eleven_am["foo"] = "bar"
    cookie_eleven_am["foo"]["expires"] = "Tue, 1 Jan 1990 11:00:00 GMT"

    cookie_two_pm = SimpleCookie()
    cookie_two_pm["foo"] = "bar"
    cookie_two_pm["foo"]["expires"] = "Tue, 1 Jan 1990 14:00:00 GMT"

    with freeze_time() as freezer:
        freezer.move_to("1990-01-01 10:00:00+00:00")

        # Re-adding the identical cookie must leave a single heap entry.
        for _ in range(10):
            jar.update_cookies(cookie_eleven_am)

        assert len(jar) == 1
        matched_cookies = jar.filter_cookies(URL("/"))
        assert len(matched_cookies) == 1
        assert "foo" in matched_cookies

        assert len(jar._expire_heap) == 1

        # At 16:00 even the 14:00 cookie is already expired when added, so
        # neither the jar nor the heap may keep an entry for it.
        freezer.move_to("1990-01-01 16:00:00+00:00")
        jar.update_cookies(cookie_two_pm)
        matched_cookies = jar.filter_cookies(URL("/"))
        assert len(matched_cookies) == 0
        assert len(jar._expire_heap) == 0


async def test_cookie_jar_filter_cookies_expires() -> None:
    """Test that calling filter_cookies will expire stale cookies."""
    jar = CookieJar()
    assert len(jar) == 0

    cookie = SimpleCookie()

    cookie["foo"] = "bar"
    cookie["foo"]["expires"] = "Tue, 1 Jan 1990 12:00:00 GMT"

    # Store the cookie while the clock is frozen before its expiry date.
    with freeze_time("1980-01-01"):
        jar.update_cookies(cookie)

    assert len(jar) == 1

    # filter_cookies should expire stale cookies
    jar.filter_cookies(URL("http://any.com/"))

    assert len(jar) == 0


async def test_cookie_jar_heap_cleanup() -> None:
    """Test that the heap gets cleaned up when there are many old expirations."""
    jar = CookieJar()
    # The heap should not be cleaned up when there are less than 100 expiration changes
    min_cookies_to_cleanup = 100

    with freeze_time() as freezer:
        freezer.move_to("1990-01-01 09:00:00+00:00")

        start_time = datetime.datetime(
            1990, 1, 1, 10, 0, 0, tzinfo=datetime.timezone.utc
        )
        # Each update changes the expiry of the same cookie by one second;
        # below the threshold every change adds one heap entry.
        for i in range(min_cookies_to_cleanup):
            cookie = SimpleCookie()
            cookie["foo"] = "bar"
            cookie["foo"]["expires"] = (
                start_time + datetime.timedelta(seconds=i)
            ).strftime("%a, %d %b %Y %H:%M:%S GMT")
            jar.update_cookies(cookie)
            assert len(jar._expire_heap) == i + 1

        assert len(jar._expire_heap) == min_cookies_to_cleanup

        # Now that we reached the minimum number of cookies to cleanup,
        # add one more cookie to trigger the cleanup
        cookie = SimpleCookie()
        cookie["foo"] = "bar"
        cookie["foo"]["expires"] = (
            start_time + datetime.timedelta(seconds=i + 1)
        ).strftime("%a, %d %b %Y %H:%M:%S GMT")
        jar.update_cookies(cookie)

        # Verify that the heap has been cleaned up
        assert len(jar) == 1
        matched_cookies = jar.filter_cookies(URL("/"))
        assert len(matched_cookies) == 1
        assert "foo" in matched_cookies

        # The heap should have been cleaned up
        assert len(jar._expire_heap) == 1


async def test_cookie_jar_heap_maintains_order_after_cleanup() -> None:
    """Test that order is maintained after cleanup."""
    jar = CookieJar()
    # The heap should not be cleaned up when there are less than 100 expiration changes
    min_cookies_to_cleanup = 100

    with freeze_time() as freezer:
        freezer.move_to("1990-01-01 09:00:00+00:00")

        # Two passes over the same 100 domains: the second pass changes each
        # cookie's expiry, so the heap holds two entries per cookie.
        for hour in (12, 13):
            for i in range(min_cookies_to_cleanup):
                cookie = SimpleCookie()
                cookie["foo"] = "bar"
                cookie["foo"]["domain"] = f"example{i}.com"
                cookie["foo"]["expires"] = f"Tue, 1 Jan 1990 {hour}:00:00 GMT"
                jar.update_cookies(cookie)

        # Get the jar into a state where the next cookie will trigger the cleanup
        assert len(jar._expire_heap) == min_cookies_to_cleanup * 2
        assert len(jar._expirations) == min_cookies_to_cleanup

        cookie = SimpleCookie()
        cookie["foo"] = "bar"
        cookie["foo"]["domain"] = "example0.com"
        cookie["foo"]["expires"] = "Tue, 1 Jan 1990 14:00:00 GMT"
        jar.update_cookies(cookie)

        assert len(jar) == 100
        # The heap should have been cleaned up
        assert len(jar._expire_heap) == 100

        # Verify that the heap is still ordered
        # (heapify must leave an already valid heap unchanged).
        heap_before = jar._expire_heap.copy()
        heapq.heapify(jar._expire_heap)
        assert heap_before == jar._expire_heap


async def test_cookie_jar_clear_domain() -> None:
    """Check that ``clear_domain`` removes the domain and subdomain cookies.

    After clearing ``example.com`` only the domain-less ``foo`` cookie may
    remain when iterating the jar.
    """
    sut = CookieJar()
    cookie = SimpleCookie()
    cookie["foo"] = "bar"
    cookie["domain_cookie"] = "value"
    cookie["domain_cookie"]["domain"] = "example.com"
    cookie["subdomain_cookie"] = "value"
    cookie["subdomain_cookie"]["domain"] = "test.example.com"
    sut.update_cookies(cookie)

    sut.clear_domain("example.com")
    iterator = iter(sut)
    morsel = next(iterator)
    assert morsel.key == "foo"
    assert morsel.value == "bar"
    # Nothing else may be left in the jar.
    with pytest.raises(StopIteration):
        next(iterator)


def test_pickle_format(cookies_to_send: SimpleCookie) -> None:
    """Test if cookiejar pickle format breaks.

    If this test fails, it may indicate that saved cookiejars will stop working.
    If that happens then:
        1. Avoid releasing the change in a bugfix release.
        2. Try to include a migration script in the release notes (example below).
        3. Use dump_cookiejar() at the top of this file to update `pickled`.
Depending on the changes made, a migration script might look like: import pickle with file_path.open("rb") as f: cookies = pickle.load(f) morsels = [(name, m) for c in cookies.values() for name, m in c.items()] cookies.clear() for name, m in morsels: cookies[(m["domain"], m["path"])][name] = m with file_path.open("wb") as f: pickle.dump(cookies, f, pickle.HIGHEST_PROTOCOL) """ if sys.version_info < (3, 14): pickled = b"\x80\x04\x95\xc8\x0b\x00\x00\x00\x00\x00\x00\x8c\x0bcollections\x94\x8c\x0bdefaultdict\x94\x93\x94\x8c\x0chttp.cookies\x94\x8c\x0cSimpleCookie\x94\x93\x94\x85\x94R\x94(\x8c\x00\x94h\x08\x86\x94h\x05)\x81\x94\x8c\rshared-cookie\x94h\x03\x8c\x06Morsel\x94\x93\x94)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94\x8c\x01/\x94\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\x08\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(\x8c\x03key\x94h\x0b\x8c\x05value\x94\x8c\x05first\x94\x8c\x0bcoded_value\x94h\x1cubs\x8c\x0bexample.com\x94h\x08\x86\x94h\x05)\x81\x94(\x8c\rdomain-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\x1e\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah!h\x1b\x8c\x06second\x94h\x1dh-ub\x8c\x14dotted-domain-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94\x8c\x0bexample.com\x94\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah.h\x1b\x8c\x05fifth\x94h\x1dh;ubu\x8c\x11test1.example.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x11subdomain1-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h<\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah?h\x
1b\x8c\x05third\x94h\x1dhKubs\x8c\x11test2.example.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x11subdomain2-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94hL\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ahOh\x1b\x8c\x06fourth\x94h\x1dh[ubs\x8c\rdifferent.org\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x17different-domain-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\\\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah_h\x1b\x8c\x05sixth\x94h\x1dhkubs\x8c\nsecure.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\rsecure-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94hl\x8c\x07max-age\x94h\x08\x8c\x06secure\x94\x88\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ahoh\x1b\x8c\x07seventh\x94h\x1dh{ubs\x8c\x0cpathtest.com\x94h\x08\x86\x94h\x05)\x81\x94(\x8c\x0eno-path-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h|\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\x7fh\x1b\x8c\x06eighth\x94h\x1dh\x8bub\x8c\x0cpath1-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94\x8c\x0cpathtest.com\x94\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\x8ch\x1b\x8c\x05ninth\x94h\x1dh\x99ubu\x8c\x0cpathtest.com\x94\x8c\x04/one\x94\x86\x94h\x05)\x81\x94\x8c\x0cpath2-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x9b\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\x9a\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08h
ttponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\x9eh\x1b\x8c\x05tenth\x94h\x1dh\xaaubs\x8c\x0cpathtest.com\x94\x8c\x08/one/two\x94\x86\x94h\x05)\x81\x94(\x8c\x0cpath3-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\xac\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\xab\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xafh\x1b\x8c\x08eleventh\x94h\x1dh\xbbub\x8c\x0cpath4-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94\x8c\t/one/two/\x94\x8c\x07comment\x94h\x08\x8c\x06domain\x94\x8c\x0cpathtest.com\x94\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xbch\x1b\x8c\x07twelfth\x94h\x1dh\xcaubu\x8c\x0fexpirestest.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x0eexpires-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94\x8c\x1cTue, 1 Jan 2999 12:00:00 GMT\x94\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\xcb\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xceh\x1b\x8c\nthirteenth\x94h\x1dh\xdbubs\x8c\x0emaxagetest.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x0emax-age-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\xdc\x8c\x07max-age\x94\x8c\x0260\x94\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xdfh\x1b\x8c\nfourteenth\x94h\x1dh\xecubs\x8c\x12invalid-values.com\x94h\x08\x86\x94h\x05)\x81\x94(\x8c\x16invalid-max-age-cookie\x94h\r)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\xed\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xf0h\x1b\x8c\tfifteenth\x94h\x1dh\xfcub\x8c\x16invalid-expires-cookie\x94h\r)\
x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\x11\x8c\x07comment\x94h\x08\x8c\x06domain\x94\x8c\x12invalid-values.com\x94\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(h\x1ah\xfdh\x1b\x8c\tsixteenth\x94h\x1dj\n\x01\x00\x00ubuu." else: pickled = b'\x80\x05\x95\x06\x08\x00\x00\x00\x00\x00\x00\x8c\x0bcollections\x94\x8c\x0bdefaultdict\x94\x93\x94\x8c\x0chttp.cookies\x94\x8c\x0cSimpleCookie\x94\x93\x94\x85\x94R\x94(\x8c\x00\x94h\x08\x86\x94h\x05)\x81\x94\x8c\rshared-cookie\x94h\x03\x8c\x06Morsel\x94\x93\x94)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94\x8c\x01/\x94\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\x08\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08\x8c\x0bpartitioned\x94h\x08u}\x94(\x8c\x03key\x94h\x0b\x8c\x05value\x94\x8c\x05first\x94\x8c\x0bcoded_value\x94h\x1dubs\x8c\x0bexample.com\x94h\x08\x86\x94h\x05)\x81\x94(\x8c\rdomain-cookie\x94h\r)\x81\x94(h\x0fh\x08h\x10h\x11h\x12h\x08h\x13h\x1fh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08h\x19h\x08u}\x94(h\x1bh"h\x1c\x8c\x06second\x94h\x1eh%ub\x8c\x14dotted-domain-cookie\x94h\r)\x81\x94(h\x0fh\x08h\x10h\x11h\x12h\x08h\x13\x8c\x0bexample.com\x94h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08h\x19h\x08u}\x94(h\x1bh&h\x1c\x8c\x05fifth\x94h\x1eh*ubu\x8c\x11test1.example.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x11subdomain1-cookie\x94h\r)\x81\x94(h\x0fh\x08h\x10h\x11h\x12h\x08h\x13h+h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08h\x19h\x08u}\x94(h\x1bh.h\x1c\x8c\x05third\x94h\x1eh1ubs\x8c\x11test2.example.com\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x11subdomain2-cookie\x94h\r)\x81\x94(h\x0fh\x08h\x10h\x11h\x12h\x08h\x13h2h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08h\x19h\x08u}\x94(h\x1bh5h\x1c\x8c\x06fourth\x94h\x1eh8ubs\x8c\rdifferent.org\x94h\x08\x86\x94h\x05)\x81\x94\x8c\x17different-domain-cookie\x94h\r)\x81\x94(h\x0fh\x08h\x10h\x11h\x12h\x08h\x
13h9h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08h\x19h\x08u}\x94(h\x1bh None: jar = CookieJar(unsafe=True, treat_as_secure_origin=url) assert jar._treat_as_secure_origin == frozenset({URL("http://127.0.0.1")}) async def test_treat_as_secure_origin() -> None: endpoint = URL("http://127.0.0.1/") jar = CookieJar(unsafe=True, treat_as_secure_origin=[endpoint]) secure_cookie = SimpleCookie( "cookie-key=cookie-value; HttpOnly; Path=/; Secure", ) jar.update_cookies( secure_cookie, endpoint, ) assert len(jar) == 1 filtered_cookies = jar.filter_cookies(request_url=endpoint) assert len(filtered_cookies) == 1 async def test_filter_cookies_does_not_leak_memory() -> None: """Test that filter_cookies doesn't create empty cookie entries. Regression test for https://github.com/aio-libs/aiohttp/issues/11052 """ jar = CookieJar() # Set a cookie with Path=/ jar.update_cookies({"test_cookie": "value; Path=/"}, URL("http://example.com/")) # Check initial state assert len(jar) == 1 initial_storage_size = len(jar._cookies) initial_morsel_cache_size = len(jar._morsel_cache) # Make multiple requests with different paths paths = [ "/", "/api", "/api/v1", "/api/v1/users", "/api/v1/users/123", "/static/css/style.css", "/images/logo.png", ] for path in paths: url = URL(f"http://example.com{path}") filtered = jar.filter_cookies(url) # Should still get the cookie assert len(filtered) == 1 assert "test_cookie" in filtered # Storage size should not grow significantly # Only the shared cookie entry ('', '') may be added final_storage_size = len(jar._cookies) assert final_storage_size <= initial_storage_size + 1 # Verify _morsel_cache doesn't leak either # It should only have entries for domains/paths where cookies exist final_morsel_cache_size = len(jar._morsel_cache) assert final_morsel_cache_size <= initial_morsel_cache_size + 1 # Verify no empty entries were created for domain-path combinations for key, cookies in jar._cookies.items(): if key != ("", ""): # Skip the shared cookie entry 
assert len(cookies) > 0, f"Empty cookie entry found for {key}" # Verify _morsel_cache entries correspond to actual cookies for key, morsels in jar._morsel_cache.items(): assert key in jar._cookies, f"Orphaned morsel cache entry for {key}" assert len(morsels) > 0, f"Empty morsel cache entry found for {key}" def test_update_cookies_from_headers() -> None: """Test update_cookies_from_headers method.""" jar: CookieJar = CookieJar() url: URL = URL("http://example.com/path") # Test with simple cookies headers = [ "session-id=123456; Path=/", "user-pref=dark-mode; Domain=.example.com", "tracking=xyz789; Secure; HttpOnly", ] jar.update_cookies_from_headers(headers, url) # Verify all cookies were added to the jar assert len(jar) == 3 # Check cookies available for HTTP URL (secure cookie should be filtered out) filtered_http: BaseCookie[str] = jar.filter_cookies(url) assert len(filtered_http) == 2 assert "session-id" in filtered_http assert filtered_http["session-id"].value == "123456" assert "user-pref" in filtered_http assert filtered_http["user-pref"].value == "dark-mode" assert "tracking" not in filtered_http # Secure cookie not available on HTTP # Check cookies available for HTTPS URL (all cookies should be available) url_https: URL = URL("https://example.com/path") filtered_https: BaseCookie[str] = jar.filter_cookies(url_https) assert len(filtered_https) == 3 assert "tracking" in filtered_https assert filtered_https["tracking"].value == "xyz789" def test_update_cookies_from_headers_duplicate_names() -> None: """Test that duplicate cookie names with different domains are preserved.""" jar: CookieJar = CookieJar() url: URL = URL("http://www.example.com/") # Headers with duplicate names but different domains headers = [ "session-id=123456; Domain=.example.com; Path=/", "session-id=789012; Domain=.www.example.com; Path=/", "user-pref=light; Domain=.example.com", "user-pref=dark; Domain=sub.example.com", ] jar.update_cookies_from_headers(headers, url) # Should have 3 
cookies (user-pref=dark for sub.example.com is rejected)
    assert len(jar) == 3

    # Verify we have both session-id cookies
    all_cookies: list[Morsel[str]] = list(jar)
    session_ids: list[Morsel[str]] = [c for c in all_cookies if c.key == "session-id"]
    assert len(session_ids) == 2

    # Check their domains are different
    domains: set[str] = {c["domain"] for c in session_ids}
    assert domains == {"example.com", "www.example.com"}


def test_update_cookies_from_headers_invalid_cookies(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test that invalid cookies are logged and skipped."""
    jar: CookieJar = CookieJar()
    url: URL = URL("http://example.com/")

    # Mix of valid and invalid cookies
    # (the second and third string literals below are adjacent, so they
    # concatenate into a single header entry).
    headers = [
        "valid-cookie=value123",
        "invalid,cookie=value; "  # Comma character is not allowed
        "HttpOnly; Path=/",
        "another-valid=value456",
    ]

    # Enable logging for the client logger
    with caplog.at_level(logging.WARNING, logger="aiohttp.client"):
        jar.update_cookies_from_headers(headers, url)

    # Check that we logged warnings for invalid cookies
    assert "Can not load cookies" in caplog.text

    # Valid cookies should still be added
    assert len(jar) >= 2  # At least the two clearly valid cookies
    filtered: BaseCookie[str] = jar.filter_cookies(url)
    assert "valid-cookie" in filtered
    assert "another-valid" in filtered


def test_update_cookies_from_headers_with_curly_braces() -> None:
    """Test that cookies with curly braces in names are now accepted (#2683)."""
    jar: CookieJar = CookieJar()
    url: URL = URL("http://example.com/")

    # Cookie names with curly braces should now be accepted
    # (the first three literals concatenate into one header entry).
    headers = [
        "ISAWPLB{A7F52349-3531-4DA9-8776-F74BC6F4F1BB}="
        "{925EC0B8-CB17-4BEB-8A35-1033813B0523}; "
        "HttpOnly; Path=/",
        "regular-cookie=value123",
    ]

    jar.update_cookies_from_headers(headers, url)

    # Both cookies should be added
    assert len(jar) == 2
    filtered: BaseCookie[str] = jar.filter_cookies(url)
    assert "ISAWPLB{A7F52349-3531-4DA9-8776-F74BC6F4F1BB}" in filtered
    assert "regular-cookie" in filtered


def
test_update_cookies_from_headers_with_special_chars() -> None:
    """Test that cookies with various special characters are accepted."""
    jar: CookieJar = CookieJar()
    url: URL = URL("http://example.com/")

    # Various special characters that should now be accepted
    headers = [
        "cookie_with_parens=(value)=test123",
        "cookie-with-brackets[index]=value456",
        "cookie@with@at=value789",
        "cookie:with:colons=value000",
    ]

    jar.update_cookies_from_headers(headers, url)

    # All cookies should be added
    assert len(jar) == 4
    filtered: BaseCookie[str] = jar.filter_cookies(url)
    assert "cookie_with_parens" in filtered
    assert "cookie-with-brackets[index]" in filtered
    assert "cookie@with@at" in filtered
    assert "cookie:with:colons" in filtered


def test_update_cookies_from_headers_empty_list() -> None:
    """Test that empty header list is handled gracefully."""
    jar: CookieJar = CookieJar()
    url: URL = URL("http://example.com/")

    # Should not raise any errors
    jar.update_cookies_from_headers([], url)

    assert len(jar) == 0


def test_update_cookies_from_headers_with_attributes() -> None:
    """Test cookies with various attributes are handled correctly."""
    jar: CookieJar = CookieJar()
    url: URL = URL("https://secure.example.com/app/page")

    headers = [
        "secure-cookie=value1; Secure; HttpOnly; SameSite=Strict",
        "expiring-cookie=value2; Max-Age=3600; Path=/app",
        "domain-cookie=value3; Domain=.example.com; Path=/",
        "dated-cookie=value4; Expires=Wed, 09 Jun 3024 10:18:14 GMT",
    ]

    jar.update_cookies_from_headers(headers, url)

    # All cookies should be stored
    assert len(jar) == 4

    # Verify secure cookie (should work on HTTPS subdomain)
    # Note: cookies without explicit path get path from URL (/app)
    filtered_https_root: BaseCookie[str] = jar.filter_cookies(
        URL("https://secure.example.com/")
    )
    assert len(filtered_https_root) == 1  # Only domain-cookie has Path=/
    assert "domain-cookie" in filtered_https_root

    # Check app path
    filtered_https_app: BaseCookie[str] = jar.filter_cookies(
        URL("https://secure.example.com/app/")
    )
    assert
len(filtered_https_app) == 4 # All cookies match assert "secure-cookie" in filtered_https_app assert "expiring-cookie" in filtered_https_app assert "domain-cookie" in filtered_https_app assert "dated-cookie" in filtered_https_app # Secure cookie should not be available on HTTP filtered_http_app: BaseCookie[str] = jar.filter_cookies( URL("http://secure.example.com/app/") ) assert "secure-cookie" not in filtered_http_app assert "expiring-cookie" in filtered_http_app # Non-secure cookies still available assert "domain-cookie" in filtered_http_app assert "dated-cookie" in filtered_http_app def test_update_cookies_from_headers_preserves_existing() -> None: """Test that update_cookies_from_headers preserves existing cookies.""" jar: CookieJar = CookieJar() url: URL = URL("http://example.com/") # Add some initial cookies jar.update_cookies( { "existing1": "value1", "existing2": "value2", }, url, ) # Add more cookies via headers headers = [ "new-cookie1=value3", "new-cookie2=value4", ] jar.update_cookies_from_headers(headers, url) # Should have all 4 cookies assert len(jar) == 4 filtered: BaseCookie[str] = jar.filter_cookies(url) assert "existing1" in filtered assert "existing2" in filtered assert "new-cookie1" in filtered assert "new-cookie2" in filtered def test_update_cookies_from_headers_overwrites_same_cookie() -> None: """Test that cookies with same name/domain/path are overwritten.""" jar: CookieJar = CookieJar() url: URL = URL("http://example.com/") # Add initial cookie jar.update_cookies({"session": "old-value"}, url) # Update with new value via headers headers = ["session=new-value"] jar.update_cookies_from_headers(headers, url) # Should still have just 1 cookie with updated value assert len(jar) == 1 filtered: BaseCookie[str] = jar.filter_cookies(url) assert filtered["session"].value == "new-value" def test_dummy_cookie_jar_update_cookies_from_headers() -> None: """Test that DummyCookieJar ignores update_cookies_from_headers.""" jar: DummyCookieJar = 
DummyCookieJar()
    url: URL = URL("http://example.com/")

    headers = [
        "cookie1=value1",
        "cookie2=value2",
    ]

    # Should not raise and should not store anything
    jar.update_cookies_from_headers(headers, url)

    assert len(jar) == 0
    filtered: BaseCookie[str] = jar.filter_cookies(url)
    assert len(filtered) == 0


async def test_shared_cookie_cache_population() -> None:
    """Test that shared cookies are cached correctly."""
    jar = CookieJar(unsafe=True)

    # Create a shared cookie (no domain/path restrictions)
    sc = SimpleCookie()
    sc["shared"] = "value"
    sc["shared"]["path"] = "/"  # Will be stripped to ""

    # Update with empty URL to avoid domain being set
    jar.update_cookies(sc, URL())

    # Verify cookie is stored at shared key
    assert ("", "") in jar._cookies
    assert "shared" in jar._cookies[("", "")]

    # Filter cookies to populate cache
    filtered = jar.filter_cookies(URL("http://example.com/"))
    assert "shared" in filtered
    assert filtered["shared"].value == "value"

    # Verify cache was populated
    assert ("", "") in jar._morsel_cache
    assert "shared" in jar._morsel_cache[("", "")]

    # Verify the cached morsel is the same one returned
    # (identity check, not just equality).
    cached_morsel = jar._morsel_cache[("", "")]["shared"]
    assert cached_morsel is filtered["shared"]


async def test_shared_cookie_cache_clearing_on_update() -> None:
    """Test that shared cookie cache is cleared when cookie is updated."""
    jar = CookieJar(unsafe=True)

    # Create initial shared cookie
    sc = SimpleCookie()
    sc["shared"] = "value1"
    sc["shared"]["path"] = "/"
    jar.update_cookies(sc, URL())

    # Filter to populate cache
    filtered1 = jar.filter_cookies(URL("http://example.com/"))
    assert filtered1["shared"].value == "value1"
    assert "shared" in jar._morsel_cache[("", "")]

    # Update the cookie with new value
    sc2 = SimpleCookie()
    sc2["shared"] = "value2"
    sc2["shared"]["path"] = "/"
    jar.update_cookies(sc2, URL())

    # Verify cache was cleared
    assert "shared" not in jar._morsel_cache[("", "")]

    # Filter again to verify new value
    filtered2 = jar.filter_cookies(URL("http://example.com/"))
assert filtered2["shared"].value == "value2" # Verify cache was repopulated with new value assert "shared" in jar._morsel_cache[("", "")] async def test_shared_cookie_cache_clearing_on_delete() -> None: """Test that shared cookie cache is cleared when cookies are deleted.""" jar = CookieJar(unsafe=True) # Create multiple shared cookies sc = SimpleCookie() sc["shared1"] = "value1" sc["shared1"]["path"] = "/" sc["shared2"] = "value2" sc["shared2"]["path"] = "/" jar.update_cookies(sc, URL()) # Filter to populate cache jar.filter_cookies(URL("http://example.com/")) assert "shared1" in jar._morsel_cache[("", "")] assert "shared2" in jar._morsel_cache[("", "")] # Delete one cookie using internal method jar._delete_cookies([("", "", "shared1")]) # Verify cookie and its cache entry were removed assert "shared1" not in jar._cookies[("", "")] assert "shared1" not in jar._morsel_cache[("", "")] # Verify other cookie remains assert "shared2" in jar._cookies[("", "")] assert "shared2" in jar._morsel_cache[("", "")] async def test_shared_cookie_cache_clearing_on_clear() -> None: """Test that shared cookie cache is cleared when jar is cleared.""" jar = CookieJar(unsafe=True) # Create shared and domain-specific cookies # Shared cookie sc1 = SimpleCookie() sc1["shared"] = "shared_value" sc1["shared"]["path"] = "/" jar.update_cookies(sc1, URL()) # Domain-specific cookie sc2 = SimpleCookie() sc2["domain_cookie"] = "domain_value" jar.update_cookies(sc2, URL("http://example.com/")) # Filter to populate caches jar.filter_cookies(URL("http://example.com/")) # Verify caches are populated assert ("", "") in jar._morsel_cache assert "shared" in jar._morsel_cache[("", "")] assert ("example.com", "") in jar._morsel_cache assert "domain_cookie" in jar._morsel_cache[("example.com", "")] # Clear all cookies jar.clear() # Verify all caches are cleared assert len(jar._morsel_cache) == 0 assert len(jar._cookies) == 0 # Verify filtering returns no cookies filtered = 
jar.filter_cookies(URL("http://example.com/"))
    assert len(filtered) == 0


async def test_shared_cookie_with_multiple_domains() -> None:
    """Test that shared cookies work across different domains."""
    jar = CookieJar(unsafe=True)

    # Create a truly shared cookie
    sc = SimpleCookie()
    sc["universal"] = "everywhere"
    sc["universal"]["path"] = "/"
    jar.update_cookies(sc, URL())

    # Test filtering for different domains
    domains = [
        "http://example.com/",
        "http://test.org/",
        "http://localhost/",
        "http://192.168.1.1/",  # IP address (requires unsafe=True)
    ]

    for domain_url in domains:
        filtered = jar.filter_cookies(URL(domain_url))
        assert "universal" in filtered
        assert filtered["universal"].value == "everywhere"

    # Verify cache is reused efficiently
    assert ("", "") in jar._morsel_cache
    assert "universal" in jar._morsel_cache[("", "")]


# === Security tests for restricted unpickler and JSON save/load ===


def test_load_rejects_malicious_pickle(tmp_path: Path) -> None:
    """Verify CookieJar.load() blocks arbitrary code execution via pickle.

    A crafted pickle payload using os.system (or any non-cookie class)
    must be rejected by the restricted unpickler.
    """
    import os

    file_path = tmp_path / "malicious.pkl"

    # Pickling an instance records the callable returned by __reduce__, so
    # a naive pickle.load() of this file would invoke os.system.
    class RCEPayload:
        def __reduce__(self) -> tuple[object, ...]:
            return (os.system, ("echo PWNED",))

    with open(file_path, "wb") as f:
        pickle.dump(RCEPayload(), f, pickle.HIGHEST_PROTOCOL)

    jar = CookieJar()
    with pytest.raises(pickle.UnpicklingError, match="Forbidden class"):
        jar.load(file_path)


def test_load_rejects_eval_payload(tmp_path: Path) -> None:
    """Verify CookieJar.load() blocks eval-based pickle payloads."""
    file_path = tmp_path / "eval_payload.pkl"

    # Same technique as above, with the builtin eval as the callable.
    class EvalPayload:
        def __reduce__(self) -> tuple[object, ...]:
            return (eval, ("__import__('os').system('echo PWNED')",))

    with open(file_path, "wb") as f:
        pickle.dump(EvalPayload(), f, pickle.HIGHEST_PROTOCOL)

    jar = CookieJar()
    with pytest.raises(pickle.UnpicklingError, match="Forbidden class"):
        jar.load(file_path)


def test_load_rejects_subprocess_payload(tmp_path: Path) -> None:
    """Verify CookieJar.load() blocks subprocess-based pickle payloads."""
    import subprocess

    file_path = tmp_path / "subprocess_payload.pkl"

    # Same technique as above, with subprocess.call as the callable.
    class SubprocessPayload:
        def __reduce__(self) -> tuple[object, ...]:
            return (subprocess.call, (["echo", "PWNED"],))

    with open(file_path, "wb") as f:
        pickle.dump(SubprocessPayload(), f, pickle.HIGHEST_PROTOCOL)

    jar = CookieJar()
    with pytest.raises(pickle.UnpicklingError, match="Forbidden class"):
        jar.load(file_path)


def test_load_falls_back_to_pickle(
    tmp_path: Path,
    cookies_to_receive: SimpleCookie,
) -> None:
    """Verify load() falls back to restricted pickle for legacy cookie files.

    Existing cookie files saved with older versions of aiohttp used pickle.
    load() should detect that the file is not JSON and fall back to the
    restricted pickle unpickler for backward compatibility.
    """
    file_path = tmp_path / "legit.pkl"

    # Write a legacy pickle file directly (as old aiohttp save() would)
    jar_save = CookieJar()
    jar_save.update_cookies(cookies_to_receive)
    with file_path.open(mode="wb") as f:
        pickle.dump(jar_save._cookies, f, pickle.HIGHEST_PROTOCOL)

    jar_load = CookieJar()
    jar_load.load(file_path=file_path)

    # Rebuild a SimpleCookie from the loaded jar so it can be compared
    # directly with the fixture.
    jar_test = SimpleCookie()
    for cookie in jar_load:
        jar_test[cookie.key] = cookie

    assert jar_test == cookies_to_receive


def test_save_load_json_roundtrip(
    tmp_path: Path,
    cookies_to_receive: SimpleCookie,
) -> None:
    """Verify save/load roundtrip preserves cookies via JSON format."""
    file_path = tmp_path / "cookies.json"

    jar_save = CookieJar()
    jar_save.update_cookies(cookies_to_receive)
    jar_save.save(file_path=file_path)

    jar_load = CookieJar()
    jar_load.load(file_path=file_path)

    # Flatten both jars into SimpleCookie objects keyed by cookie name
    # before comparing them.
    saved_cookies = SimpleCookie()
    for cookie in jar_save:
        saved_cookies[cookie.key] = cookie

    loaded_cookies = SimpleCookie()
    for cookie in jar_load:
        loaded_cookies[cookie.key] = cookie

    assert saved_cookies == loaded_cookies


def test_save_load_json_partitioned_cookies(tmp_path: Path) -> None:
    """Verify save/load roundtrip works with partitioned cookies."""
    file_path = tmp_path / "partitioned.json"

    jar_save = CookieJar()
    jar_save.update_cookies_from_headers(
        ["session=cookie; Partitioned"], URL("https://example.com/")
    )
    jar_save.save(file_path=file_path)

    jar_load = CookieJar()
    jar_load.load(file_path=file_path)

    # Compare individual cookie values (same approach as test_save_load_partitioned_cookies)
    saved = list(jar_save)
    loaded = list(jar_load)
    assert len(saved) == len(loaded)
    for s, lo in zip(saved, loaded):
        assert s.key == lo.key
        assert s.value == lo.value
        assert s["domain"] == lo["domain"]
        assert s["path"] == lo["path"]


def test_json_format_is_safe(tmp_path: Path) -> None:
    """Verify the JSON file format cannot execute code on load."""
    import json

    file_path = tmp_path / "safe.json"
    # Write something that might look dangerous but is just data
    malicious_data = {
        "evil.com|/": {
"session": { "key": "session", "value": "__import__('os').system('echo PWNED')", "coded_value": "__import__('os').system('echo PWNED')", } } } with open(file_path, "w") as f: json.dump(malicious_data, f) jar = CookieJar() jar.load(file_path=file_path) # The "malicious" string is just a cookie value, not executed code cookies = list(jar) assert len(cookies) == 1 assert cookies[0].value == "__import__('os').system('echo PWNED')" def test_save_load_json_secure_cookies(tmp_path: Path) -> None: """Verify save/load preserves Secure and HttpOnly flags.""" file_path = tmp_path / "secure.json" jar_save = CookieJar() jar_save.update_cookies_from_headers( ["token=abc123; Secure; HttpOnly; Path=/; Domain=example.com"], URL("https://example.com/"), ) jar_save.save(file_path=file_path) jar_load = CookieJar() jar_load.load(file_path=file_path) loaded_cookies = list(jar_load) assert len(loaded_cookies) == 1 cookie = loaded_cookies[0] assert cookie.key == "token" assert cookie.value == "abc123" assert cookie["secure"] is True assert cookie["httponly"] is True assert cookie["domain"] == "example.com" ================================================ FILE: tests/test_flowcontrol_streams.py ================================================ import asyncio from unittest import mock import pytest from aiohttp import streams from aiohttp.base_protocol import BaseProtocol @pytest.fixture def protocol() -> BaseProtocol: return mock.create_autospec(BaseProtocol, spec_set=True, instance=True, _reading_paused=False) # type: ignore[no-any-return] @pytest.fixture def stream( loop: asyncio.AbstractEventLoop, protocol: BaseProtocol ) -> streams.StreamReader: return streams.StreamReader(protocol, limit=1, loop=loop) class TestFlowControlStreamReader: async def test_read(self, stream: streams.StreamReader) -> None: stream.feed_data(b"da") res = await stream.read(1) assert res == b"d" assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_read_resume_paused(self, 
stream: streams.StreamReader) -> None: stream.feed_data(b"test") stream._protocol._reading_paused = True res = await stream.read(1) assert res == b"t" assert stream._protocol.pause_reading.called # type: ignore[attr-defined] async def test_readline(self, stream: streams.StreamReader) -> None: stream.feed_data(b"d\n") res = await stream.readline() assert res == b"d\n" assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readline_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True stream.feed_data(b"d\n") res = await stream.readline() assert res == b"d\n" assert stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readany(self, stream: streams.StreamReader) -> None: stream.feed_data(b"data") res = await stream.readany() assert res == b"data" assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readany_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True stream.feed_data(b"data") res = await stream.readany() assert res == b"data" assert stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readchunk(self, stream: streams.StreamReader) -> None: stream.feed_data(b"data") res, end_of_http_chunk = await stream.readchunk() assert res == b"data" assert not end_of_http_chunk assert not stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readchunk_resume_paused(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True stream.feed_data(b"data") res, end_of_http_chunk = await stream.readchunk() assert res == b"data" assert not end_of_http_chunk assert stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_readexactly(self, stream: streams.StreamReader) -> None: stream.feed_data(b"data") res = await stream.readexactly(3) assert res == b"dat" assert not 
stream._protocol.resume_reading.called # type: ignore[attr-defined] async def test_feed_data(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = False stream.feed_data(b"datadata") assert stream._protocol.pause_reading.called # type: ignore[attr-defined] async def test_read_nowait(self, stream: streams.StreamReader) -> None: stream._protocol._reading_paused = True stream.feed_data(b"data1") stream.feed_data(b"data2") stream.feed_data(b"data3") res = await stream.read(5) assert res == b"data1" assert stream._protocol.resume_reading.call_count == 0 # type: ignore[attr-defined] res = stream.read_nowait(5) assert res == b"data2" assert stream._protocol.resume_reading.call_count == 0 # type: ignore[attr-defined] res = stream.read_nowait(5) assert res == b"data3" assert stream._protocol.resume_reading.call_count == 1 # type: ignore[attr-defined] stream._protocol._reading_paused = False res = stream.read_nowait(5) assert res == b"" assert stream._protocol.resume_reading.call_count == 1 # type: ignore[attr-defined] async def test_resumed_on_eof(self, stream: streams.StreamReader) -> None: stream.feed_data(b"data") assert stream._protocol.pause_reading.call_count == 1 # type: ignore[attr-defined] assert stream._protocol.resume_reading.call_count == 0 # type: ignore[attr-defined] stream._protocol._reading_paused = True stream.feed_eof() assert stream._protocol.resume_reading.call_count == 1 # type: ignore[attr-defined] async def test_stream_reader_eof_when_full() -> None: loop = asyncio.get_event_loop() protocol = BaseProtocol(loop=loop) protocol.transport = asyncio.Transport() stream = streams.StreamReader(protocol, 1024, loop=loop) data_len = stream._high_water + 1 stream.feed_data(b"0" * data_len) assert protocol._reading_paused stream.feed_eof() assert not protocol._reading_paused ================================================ FILE: tests/test_formdata.py ================================================ import io from unittest import mock 
import pytest from aiohttp import FormData, web from aiohttp.http_writer import StreamWriter from aiohttp.pytest_plugin import AiohttpClient @pytest.fixture def buf() -> bytearray: return bytearray() @pytest.fixture def writer(buf: bytearray) -> StreamWriter: writer = mock.create_autospec(StreamWriter, spec_set=True) async def write(chunk: bytes) -> None: buf.extend(chunk) writer.write.side_effect = write return writer # type: ignore[no-any-return] def test_formdata_multipart(buf: bytearray) -> None: form = FormData(default_to_multipart=False) assert not form.is_multipart form.add_field("test", b"test", filename="test.txt") assert form.is_multipart def test_form_data_is_multipart_param(buf: bytearray) -> None: form = FormData(default_to_multipart=True) assert form.is_multipart form.add_field("test", "test") assert form.is_multipart @pytest.mark.parametrize("obj", (object(), None)) def test_invalid_formdata_payload_multipart(obj: object) -> None: form = FormData() form.add_field("test", obj, filename="test.txt") with pytest.raises(TypeError, match="Can not serialize value"): form() @pytest.mark.parametrize("obj", (object(), None)) def test_invalid_formdata_payload_urlencoded(obj: object) -> None: form = FormData({"test": obj}) with pytest.raises(TypeError, match="expected str"): form() def test_invalid_formdata_params() -> None: with pytest.raises(TypeError): FormData("asdasf") def test_invalid_formdata_params2() -> None: with pytest.raises(TypeError): FormData("as") # 2-char str is not allowed async def test_formdata_textio_charset(buf: bytearray, writer: StreamWriter) -> None: form = FormData() body = io.TextIOWrapper(io.BytesIO(b"\xe6\x97\xa5\xe6\x9c\xac"), encoding="utf-8") form.add_field("foo", body, content_type="text/plain; charset=shift-jis") payload = form() await payload.write(writer) assert b"charset=shift-jis" in buf assert b"\x93\xfa\x96{" in buf @pytest.mark.parametrize("val", (0, 0.1, {}, [], b"foo")) def test_invalid_type_formdata_content_type(val: 
object) -> None: form = FormData() with pytest.raises(TypeError): form.add_field("foo", "bar", content_type=val) # type: ignore[arg-type] @pytest.mark.parametrize("val", ("\r", "\n", "a\ra\n", "a\na\r")) def test_invalid_value_formdata_content_type(val: str) -> None: form = FormData() with pytest.raises(ValueError): form.add_field("foo", "bar", content_type=val) def test_invalid_formdata_filename() -> None: form = FormData() invalid_vals = [0, 0.1, {}, [], b"foo"] for invalid_val in invalid_vals: with pytest.raises(TypeError): form.add_field("foo", "bar", filename=invalid_val) # type: ignore[arg-type] async def test_formdata_field_name_is_quoted( buf: bytearray, writer: StreamWriter ) -> None: form = FormData(charset="ascii") form.add_field("email 1", "xxx@x.co", content_type="multipart/form-data") payload = form() await payload.write(writer) assert b'name="email\\ 1"' in buf async def test_formdata_field_name_is_not_quoted( buf: bytearray, writer: StreamWriter ) -> None: form = FormData(quote_fields=False, charset="ascii") form.add_field("email 1", "xxx@x.co", content_type="multipart/form-data") payload = form() await payload.write(writer) assert b'name="email 1"' in buf async def test_formdata_is_reusable(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.add_routes([web.post("/", handler)]) client = await aiohttp_client(app) data = FormData() data.add_field("test", "test_value", content_type="application/json") # First request resp1 = await client.post("/", data=data) assert resp1.status == 200 resp1.release() # Second request - should work without RuntimeError resp2 = await client.post("/", data=data) assert resp2.status == 200 resp2.release() # Third request to ensure continued reusability resp3 = await client.post("/", data=data) assert resp3.status == 200 resp3.release() async def test_formdata_boundary_param() -> None: boundary = "some_boundary" form = 
FormData(boundary=boundary) assert form._writer.boundary == boundary async def test_formdata_reusability_multipart( writer: StreamWriter, buf: bytearray ) -> None: form = FormData() form.add_field("name", "value") form.add_field("file", b"content", filename="test.txt", content_type="text/plain") # First call - should generate multipart payload payload1 = form() assert form.is_multipart buf.clear() await payload1.write(writer) result1 = bytes(buf) # Verify first result contains expected content assert b"name" in result1 assert b"value" in result1 assert b"test.txt" in result1 assert b"content" in result1 assert b"text/plain" in result1 # Second call - should generate identical multipart payload payload2 = form() buf.clear() await payload2.write(writer) result2 = bytes(buf) # Results should be identical (same boundary and content) assert result1 == result2 # Third call to ensure continued reusability payload3 = form() buf.clear() await payload3.write(writer) result3 = bytes(buf) assert result1 == result3 async def test_formdata_reusability_urlencoded( writer: StreamWriter, buf: bytearray ) -> None: form = FormData() form.add_field("key1", "value1") form.add_field("key2", "value2") # First call - should generate urlencoded payload payload1 = form() assert not form.is_multipart buf.clear() await payload1.write(writer) result1 = bytes(buf) # Verify first result contains expected content assert b"key1=value1" in result1 assert b"key2=value2" in result1 # Second call - should generate identical urlencoded payload payload2 = form() buf.clear() await payload2.write(writer) result2 = bytes(buf) # Results should be identical assert result1 == result2 # Third call to ensure continued reusability payload3 = form() buf.clear() await payload3.write(writer) result3 = bytes(buf) assert result1 == result3 async def test_formdata_reusability_after_adding_fields( writer: StreamWriter, buf: bytearray ) -> None: form = FormData() form.add_field("field1", "value1") # First call payload1 
= form() buf.clear() await payload1.write(writer) result1 = bytes(buf) # Add more fields after first call form.add_field("field2", "value2") # Second call should include new field payload2 = form() buf.clear() await payload2.write(writer) result2 = bytes(buf) # Results should be different assert result1 != result2 assert b"field1=value1" in result2 assert b"field2=value2" in result2 assert b"field2=value2" not in result1 # Third call should be same as second payload3 = form() buf.clear() await payload3.write(writer) result3 = bytes(buf) assert result2 == result3 async def test_formdata_reusability_with_io_fields( writer: StreamWriter, buf: bytearray ) -> None: form = FormData() # Create BytesIO and StringIO objects bytes_io = io.BytesIO(b"bytes content") string_io = io.StringIO("string content") form.add_field( "bytes_field", bytes_io, filename="bytes.bin", content_type="application/octet-stream", ) form.add_field( "string_field", string_io, filename="text.txt", content_type="text/plain" ) # First call payload1 = form() buf.clear() await payload1.write(writer) result1 = bytes(buf) assert b"bytes content" in result1 assert b"string content" in result1 # Reset IO objects for reuse bytes_io.seek(0) string_io.seek(0) # Second call - should work with reset IO objects payload2 = form() buf.clear() await payload2.write(writer) result2 = bytes(buf) # Should produce identical results assert result1 == result2 ================================================ FILE: tests/test_helpers.py ================================================ import asyncio import base64 import datetime import gc import sys import weakref from collections.abc import Iterator from math import ceil, modf from pathlib import Path from types import MappingProxyType from unittest import mock from urllib.request import getproxies_environment import pytest from multidict import CIMultiDict, MultiDict, MultiDictProxy from yarl import URL from aiohttp import helpers, web from aiohttp.helpers import ( 
EMPTY_BODY_METHODS, is_expected_content_type, must_be_empty_body, parse_http_date, should_remove_content_length, ) # ------------------- parse_mimetype ---------------------------------- @pytest.mark.parametrize( "mimetype, expected", [ ("", helpers.MimeType("", "", "", MultiDictProxy(MultiDict()))), ("*", helpers.MimeType("*", "*", "", MultiDictProxy(MultiDict()))), ( "application/json", helpers.MimeType("application", "json", "", MultiDictProxy(MultiDict())), ), ( "application/json; charset=utf-8", helpers.MimeType( "application", "json", "", MultiDictProxy(MultiDict({"charset": "utf-8"})), ), ), ( """application/json; charset=utf-8;""", helpers.MimeType( "application", "json", "", MultiDictProxy(MultiDict({"charset": "utf-8"})), ), ), ( 'ApPlIcAtIoN/JSON;ChaRseT="UTF-8"', helpers.MimeType( "application", "json", "", MultiDictProxy(MultiDict({"charset": "UTF-8"})), ), ), ( "application/rss+xml", helpers.MimeType("application", "rss", "xml", MultiDictProxy(MultiDict())), ), ( "text/plain;base64", helpers.MimeType( "text", "plain", "", MultiDictProxy(MultiDict({"base64": ""})) ), ), ], ) def test_parse_mimetype(mimetype: str, expected: helpers.MimeType) -> None: result = helpers.parse_mimetype(mimetype) assert isinstance(result, helpers.MimeType) assert result == expected # ------------------- parse_content_type ------------------------------ @pytest.mark.parametrize( "content_type, expected", [ ( "text/plain", ("text/plain", MultiDictProxy(MultiDict())), ), ( "wrong", ("application/octet-stream", MultiDictProxy(MultiDict())), ), ], ) def test_parse_content_type( content_type: str, expected: tuple[str, MappingProxyType[str, str]] ) -> None: result = helpers.parse_content_type(content_type) assert result == expected # ------------------- guess_filename ---------------------------------- def test_guess_filename_with_file_object(tmp_path: Path) -> None: file_path = tmp_path / "test_guess_filename" with file_path.open("w+b") as fp: assert helpers.guess_filename(fp, 
"no-throw") is not None def test_guess_filename_with_path(tmp_path: Path) -> None: file_path = tmp_path / "test_guess_filename" assert helpers.guess_filename(file_path, "no-throw") is not None def test_guess_filename_with_default() -> None: assert helpers.guess_filename(None, "no-throw") == "no-throw" # ------------------- BasicAuth ----------------------------------- def test_basic_auth1() -> None: # missing password here with pytest.raises(ValueError): helpers.BasicAuth(None) # type: ignore[arg-type] def test_basic_auth2() -> None: with pytest.raises(ValueError): helpers.BasicAuth("nkim", None) # type: ignore[arg-type] def test_basic_with_auth_colon_in_login() -> None: with pytest.raises(ValueError): helpers.BasicAuth("nkim:1", "pwd") def test_basic_auth3() -> None: auth = helpers.BasicAuth("nkim") assert auth.login == "nkim" assert auth.password == "" def test_basic_auth4() -> None: auth = helpers.BasicAuth("nkim", "pwd") assert auth.login == "nkim" assert auth.password == "pwd" assert auth.encode() == "Basic bmtpbTpwd2Q=" @pytest.mark.parametrize( "header", ( "Basic bmtpbTpwd2Q=", "basic bmtpbTpwd2Q=", ), ) def test_basic_auth_decode(header: str) -> None: auth = helpers.BasicAuth.decode(header) assert auth.login == "nkim" assert auth.password == "pwd" def test_basic_auth_invalid() -> None: with pytest.raises(ValueError): helpers.BasicAuth.decode("bmtpbTpwd2Q=") def test_basic_auth_decode_not_basic() -> None: with pytest.raises(ValueError): helpers.BasicAuth.decode("Complex bmtpbTpwd2Q=") def test_basic_auth_decode_bad_base64() -> None: with pytest.raises(ValueError): helpers.BasicAuth.decode("Basic bmtpbTpwd2Q") @pytest.mark.parametrize("header", ("Basic ???", "Basic ")) def test_basic_auth_decode_illegal_chars_base64(header: str) -> None: with pytest.raises(ValueError, match="Invalid base64 encoding."): helpers.BasicAuth.decode(header) def test_basic_auth_decode_invalid_credentials() -> None: with pytest.raises(ValueError, match="Invalid credentials."): header 
= "Basic {}".format(base64.b64encode(b"username").decode()) helpers.BasicAuth.decode(header) @pytest.mark.parametrize( "credentials, expected_auth", ( (":", helpers.BasicAuth(login="", password="", encoding="latin1")), ( "username:", helpers.BasicAuth(login="username", password="", encoding="latin1"), ), ( ":password", helpers.BasicAuth(login="", password="password", encoding="latin1"), ), ( "username:password", helpers.BasicAuth(login="username", password="password", encoding="latin1"), ), ), ) def test_basic_auth_decode_blank_username( # type: ignore[misc] credentials: str, expected_auth: helpers.BasicAuth ) -> None: header = f"Basic {base64.b64encode(credentials.encode()).decode()}" assert helpers.BasicAuth.decode(header) == expected_auth def test_basic_auth_from_url() -> None: url = URL("http://user:pass@example.com") auth = helpers.BasicAuth.from_url(url) assert auth is not None assert auth.login == "user" assert auth.password == "pass" def test_basic_auth_no_user_from_url() -> None: url = URL("http://:pass@example.com") auth = helpers.BasicAuth.from_url(url) assert auth is not None assert auth.login == "" assert auth.password == "pass" def test_basic_auth_no_auth_from_url() -> None: url = URL("http://example.com") auth = helpers.BasicAuth.from_url(url) assert auth is None def test_basic_auth_from_not_url() -> None: with pytest.raises(TypeError): helpers.BasicAuth.from_url("http://user:pass@example.com") # type: ignore[arg-type] # ----------------------------------- is_ip_address() ---------------------- def test_is_ip_address() -> None: assert helpers.is_ip_address("127.0.0.1") assert helpers.is_ip_address("::1") assert helpers.is_ip_address("FE80:0000:0000:0000:0202:B3FF:FE1E:8329") # Hostnames assert not helpers.is_ip_address("localhost") assert not helpers.is_ip_address("www.example.com") def test_ipv4_addresses() -> None: ip_addresses = [ "0.0.0.0", "127.0.0.1", "255.255.255.255", ] for address in ip_addresses: assert helpers.is_ip_address(address) def 
test_ipv6_addresses() -> None: ip_addresses = [ "0:0:0:0:0:0:0:0", "FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF", "00AB:0002:3008:8CFD:00AB:0002:3008:8CFD", "00ab:0002:3008:8cfd:00ab:0002:3008:8cfd", "AB:02:3008:8CFD:AB:02:3008:8CFD", "AB:02:3008:8CFD::02:3008:8CFD", "::", "1::1", ] for address in ip_addresses: assert helpers.is_ip_address(address) def test_host_addresses() -> None: hosts = [ "www.four.part.host", "www.python.org", "foo.bar", "localhost", ] for host in hosts: assert not helpers.is_ip_address(host) def test_is_ip_address_invalid_type() -> None: with pytest.raises(TypeError): helpers.is_ip_address(123) # type: ignore[arg-type] with pytest.raises(TypeError): helpers.is_ip_address(object()) # type: ignore[arg-type] # ----------------------------------- TimeoutHandle ------------------- def test_timeout_handle(loop: asyncio.AbstractEventLoop) -> None: handle = helpers.TimeoutHandle(loop, 10.2) cb = mock.Mock() handle.register(cb) assert cb == handle._callbacks[0][0] handle.close() assert not handle._callbacks def test_when_timeout_smaller_second(loop: asyncio.AbstractEventLoop) -> None: timeout = 0.1 handle = helpers.TimeoutHandle(loop, timeout) timer = loop.time() + timeout start_handle = handle.start() assert start_handle is not None when = start_handle.when() handle.close() assert isinstance(when, float) assert when - timer == pytest.approx(0, abs=0.001) def test_when_timeout_smaller_second_with_low_threshold( loop: asyncio.AbstractEventLoop, ) -> None: timeout = 0.1 handle = helpers.TimeoutHandle(loop, timeout, 0.01) timer = loop.time() + timeout start_handle = handle.start() assert start_handle is not None when = start_handle.when() handle.close() assert isinstance(when, int) assert when == ceil(timer) def test_timeout_handle_cb_exc(loop: asyncio.AbstractEventLoop) -> None: handle = helpers.TimeoutHandle(loop, 10.2) cb = mock.Mock() handle.register(cb) cb.side_effect = ValueError() handle() assert cb.called assert not handle._callbacks def 
test_timer_context_not_cancelled() -> None: with mock.patch("aiohttp.helpers.asyncio") as m_asyncio: m_asyncio.TimeoutError = asyncio.TimeoutError loop = mock.Mock() ctx = helpers.TimerContext(loop) ctx.timeout() with pytest.raises(asyncio.TimeoutError): with ctx: pass assert not m_asyncio.current_task.return_value.cancel.called @pytest.mark.skipif( sys.version_info < (3, 11), reason="Python 3.11+ is required for .cancelling()" ) async def test_timer_context_timeout_does_not_leak_upward() -> None: """Verify that the TimerContext does not leak cancellation outside the context manager.""" loop = asyncio.get_running_loop() ctx = helpers.TimerContext(loop) current_task = asyncio.current_task() assert current_task is not None with pytest.raises(asyncio.TimeoutError): with ctx: assert current_task.cancelling() == 0 loop.call_soon(ctx.timeout) await asyncio.sleep(1) # After the context manager exits, the task should no longer be cancelling assert current_task.cancelling() == 0 @pytest.mark.skipif( sys.version_info < (3, 11), reason="Python 3.11+ is required for .cancelling()" ) async def test_timer_context_timeout_does_swallow_cancellation() -> None: """Verify that the TimerContext does not swallow cancellation.""" loop = asyncio.get_running_loop() current_task = asyncio.current_task() assert current_task is not None ctx = helpers.TimerContext(loop) async def task_with_timeout() -> None: new_task = asyncio.current_task() assert new_task is not None with pytest.raises(asyncio.TimeoutError): with ctx: assert new_task.cancelling() == 0 await asyncio.sleep(1) task = asyncio.create_task(task_with_timeout()) await asyncio.sleep(0) task.cancel() assert task.cancelling() == 1 ctx.timeout() # Cancellation should not leak into the current task assert current_task.cancelling() == 0 # Cancellation should not be swallowed if the task is cancelled # and it also times out await asyncio.sleep(0) with pytest.raises(asyncio.CancelledError): await task assert task.cancelling() == 1 def 
test_timer_context_no_task(loop: asyncio.AbstractEventLoop) -> None: with pytest.raises(RuntimeError): with helpers.TimerContext(loop): pass async def test_weakref_handle(loop: asyncio.AbstractEventLoop) -> None: cb = mock.Mock() helpers.weakref_handle(cb, "test", 0.01, loop) await asyncio.sleep(0.1) assert cb.test.called async def test_weakref_handle_with_small_threshold( loop: asyncio.AbstractEventLoop, ) -> None: cb = mock.Mock() loop = mock.Mock() loop.time.return_value = 10 helpers.weakref_handle(cb, "test", 0.1, loop, 0.01) loop.call_at.assert_called_with( 11, helpers._weakref_handle, (weakref.ref(cb), "test") ) async def test_weakref_handle_weak(loop: asyncio.AbstractEventLoop) -> None: cb = mock.Mock() helpers.weakref_handle(cb, "test", 0.01, loop) del cb gc.collect() await asyncio.sleep(0.1) # -------------------- ceil math ------------------------- def test_ceil_call_later() -> None: cb = mock.Mock() loop = mock.Mock() loop.time.return_value = 10.1 helpers.call_later(cb, 10.1, loop) loop.call_at.assert_called_with(21.0, cb) async def test_ceil_timeout_round(loop: asyncio.AbstractEventLoop) -> None: async with helpers.ceil_timeout(7.5) as cm: if sys.version_info >= (3, 11): w = cm.when() assert w is not None frac, integer = modf(w) else: assert cm.deadline is not None frac, integer = modf(cm.deadline) assert frac == 0 async def test_ceil_timeout_small(loop: asyncio.AbstractEventLoop) -> None: async with helpers.ceil_timeout(1.1) as cm: if sys.version_info >= (3, 11): w = cm.when() assert w is not None frac, integer = modf(w) else: assert cm.deadline is not None frac, integer = modf(cm.deadline) # a chance for exact integer with zero fraction is negligible assert frac != 0 def test_ceil_call_later_with_small_threshold() -> None: cb = mock.Mock() loop = mock.Mock() loop.time.return_value = 10.1 helpers.call_later(cb, 4.5, loop, 1) loop.call_at.assert_called_with(15, cb) def test_ceil_call_later_no_timeout() -> None: cb = mock.Mock() loop = mock.Mock() 
helpers.call_later(cb, 0, loop) assert not loop.call_at.called async def test_ceil_timeout_none(loop: asyncio.AbstractEventLoop) -> None: async with helpers.ceil_timeout(None) as cm: if sys.version_info >= (3, 11): assert cm.when() is None else: assert cm.deadline is None async def test_ceil_timeout_small_with_overriden_threshold( loop: asyncio.AbstractEventLoop, ) -> None: async with helpers.ceil_timeout(1.5, ceil_threshold=1) as cm: if sys.version_info >= (3, 11): w = cm.when() assert w is not None frac, integer = modf(w) else: assert cm.deadline is not None frac, integer = modf(cm.deadline) assert frac == 0 # -------------------------------- ContentDisposition ------------------- @pytest.mark.parametrize( "params, quote_fields, _charset, expected", [ (dict(foo="bar"), True, "utf-8", 'attachment; foo="bar"'), (dict(foo="bar[]"), True, "utf-8", 'attachment; foo="bar[]"'), (dict(foo=' a""b\\'), True, "utf-8", 'attachment; foo="\\ a\\"\\"b\\\\"'), (dict(foo="bär"), True, "utf-8", "attachment; foo*=utf-8''b%C3%A4r"), (dict(foo='bär "\\'), False, "utf-8", 'attachment; foo="bär \\"\\\\"'), (dict(foo="bär"), True, "latin-1", "attachment; foo*=latin-1''b%E4r"), (dict(filename="bär"), True, "utf-8", 'attachment; filename="b%C3%A4r"'), (dict(filename="bär"), True, "latin-1", 'attachment; filename="b%E4r"'), ( dict(filename='bär "\\'), False, "utf-8", 'attachment; filename="bär \\"\\\\"', ), ], ) def test_content_disposition( params: dict[str, str], quote_fields: bool, _charset: str, expected: str ) -> None: result = helpers.content_disposition_header( "attachment", quote_fields=quote_fields, _charset=_charset, params=params ) assert result == expected def test_content_disposition_bad_type() -> None: with pytest.raises(ValueError): helpers.content_disposition_header("foo bar") with pytest.raises(ValueError): helpers.content_disposition_header("—Ç–µ—Å—Ç") with pytest.raises(ValueError): helpers.content_disposition_header("foo\x00bar") with pytest.raises(ValueError): 
helpers.content_disposition_header("") def test_set_content_disposition_bad_param() -> None: with pytest.raises(ValueError): helpers.content_disposition_header("inline", params={"foo bar": "baz"}) with pytest.raises(ValueError): helpers.content_disposition_header("inline", params={"—Ç–µ—Å—Ç": "baz"}) with pytest.raises(ValueError): helpers.content_disposition_header("inline", params={"": "baz"}) with pytest.raises(ValueError): helpers.content_disposition_header("inline", params={"foo\x00bar": "baz"}) # --------------------- proxies_from_env ------------------------------ @pytest.mark.parametrize( ("proxy_env_vars", "url_input", "expected_scheme"), ( ({"http_proxy": "http://aiohttp.io/path"}, "http://aiohttp.io/path", "http"), ({"https_proxy": "http://aiohttp.io/path"}, "http://aiohttp.io/path", "https"), ({"ws_proxy": "http://aiohttp.io/path"}, "http://aiohttp.io/path", "ws"), ({"wss_proxy": "http://aiohttp.io/path"}, "http://aiohttp.io/path", "wss"), ), indirect=["proxy_env_vars"], ids=("http", "https", "ws", "wss"), ) @pytest.mark.usefixtures("proxy_env_vars") def test_proxies_from_env(url_input: str, expected_scheme: str) -> None: url = URL(url_input) ret = helpers.proxies_from_env() assert ret.keys() == {expected_scheme} assert ret[expected_scheme].proxy == url assert ret[expected_scheme].proxy_auth is None @pytest.mark.parametrize( ("proxy_env_vars", "url_input", "expected_scheme"), ( ( {"https_proxy": "https://aiohttp.io/path"}, "https://aiohttp.io/path", "https", ), ({"wss_proxy": "wss://aiohttp.io/path"}, "wss://aiohttp.io/path", "wss"), ), indirect=["proxy_env_vars"], ids=("https", "wss"), ) @pytest.mark.usefixtures("proxy_env_vars") def test_proxies_from_env_skipped( caplog: pytest.LogCaptureFixture, url_input: str, expected_scheme: str ) -> None: url = URL(url_input) assert helpers.proxies_from_env() == {} assert len(caplog.records) == 1 log_message = ( f"{expected_scheme.upper()!s} proxies {url!s} are not supported, ignoring" ) assert 
caplog.record_tuples == [("aiohttp.client", 30, log_message)] @pytest.mark.parametrize( ("proxy_env_vars", "url_input", "expected_scheme"), ( ( {"http_proxy": "http://user:pass@aiohttp.io/path"}, "http://user:pass@aiohttp.io/path", "http", ), ), indirect=["proxy_env_vars"], ids=("http",), ) @pytest.mark.usefixtures("proxy_env_vars") def test_proxies_from_env_http_with_auth(url_input: str, expected_scheme: str) -> None: url = URL("http://user:pass@aiohttp.io/path") ret = helpers.proxies_from_env() assert ret.keys() == {expected_scheme} assert ret[expected_scheme].proxy == url.with_user(None) proxy_auth = ret[expected_scheme].proxy_auth assert proxy_auth is not None assert proxy_auth.login == "user" assert proxy_auth.password == "pass" assert proxy_auth.encoding == "latin1" # --------------------- get_env_proxy_for_url ------------------------------ @pytest.fixture def proxy_env_vars( monkeypatch: pytest.MonkeyPatch, request: pytest.FixtureRequest ) -> object: for schema in getproxies_environment().keys(): monkeypatch.delenv(f"{schema}_proxy", False) for proxy_type, proxy_list in request.param.items(): monkeypatch.setenv(proxy_type, proxy_list) return request.param @pytest.mark.parametrize( ("proxy_env_vars", "url_input", "expected_err_msg"), ( ( {"no_proxy": "aiohttp.io"}, "http://aiohttp.io/path", r"Proxying is disallowed for `'aiohttp.io'`", ), ( {"no_proxy": "aiohttp.io,proxy.com"}, "http://aiohttp.io/path", r"Proxying is disallowed for `'aiohttp.io'`", ), ( {"http_proxy": "http://example.com"}, "https://aiohttp.io/path", r"No proxies found for `https://aiohttp.io/path` in the env", ), ( {"https_proxy": "https://example.com"}, "http://aiohttp.io/path", r"No proxies found for `http://aiohttp.io/path` in the env", ), ( {}, "https://aiohttp.io/path", r"No proxies found for `https://aiohttp.io/path` in the env", ), ( {"https_proxy": "https://example.com"}, "", r"No proxies found for `` in the env", ), ), indirect=["proxy_env_vars"], ids=( 
"url_matches_the_no_proxy_list", "url_matches_the_no_proxy_list_multiple", "url_scheme_does_not_match_http_proxy_list", "url_scheme_does_not_match_https_proxy_list", "no_proxies_are_set", "url_is_empty", ), ) @pytest.mark.usefixtures("proxy_env_vars") def test_get_env_proxy_for_url_negative(url_input: str, expected_err_msg: str) -> None: url = URL(url_input) with pytest.raises(LookupError, match=expected_err_msg): helpers.get_env_proxy_for_url(url) @pytest.mark.parametrize( ("proxy_env_vars", "url_input"), ( ({"http_proxy": "http://example.com"}, "http://aiohttp.io/path"), ({"https_proxy": "http://example.com"}, "https://aiohttp.io/path"), ( {"http_proxy": "http://example.com,http://proxy.org"}, "http://aiohttp.io/path", ), ), indirect=["proxy_env_vars"], ids=( "url_scheme_match_http_proxy_list", "url_scheme_match_https_proxy_list", "url_scheme_match_http_proxy_list_multiple", ), ) def test_get_env_proxy_for_url(proxy_env_vars: dict[str, str], url_input: str) -> None: url = URL(url_input) proxy, proxy_auth = helpers.get_env_proxy_for_url(url) proxy_list = proxy_env_vars[url.scheme + "_proxy"] assert proxy == URL(proxy_list) assert proxy_auth is None # ------------- set_result / set_exception ---------------------- async def test_set_result(loop: asyncio.AbstractEventLoop) -> None: fut = loop.create_future() helpers.set_result(fut, 123) assert 123 == await fut async def test_set_result_cancelled(loop: asyncio.AbstractEventLoop) -> None: fut = loop.create_future() fut.cancel() helpers.set_result(fut, 123) with pytest.raises(asyncio.CancelledError): await fut async def test_set_exception(loop: asyncio.AbstractEventLoop) -> None: fut = loop.create_future() helpers.set_exception(fut, RuntimeError()) with pytest.raises(RuntimeError): await fut async def test_set_exception_cancelled(loop: asyncio.AbstractEventLoop) -> None: fut = loop.create_future() fut.cancel() helpers.set_exception(fut, RuntimeError()) with pytest.raises(asyncio.CancelledError): await fut # ----------- 
ChainMapProxy -------------------------- AppKeyDict = dict[str | web.AppKey[object], object] class TestChainMapProxy: def test_inheritance(self) -> None: with pytest.raises(TypeError): class A(helpers.ChainMapProxy): # type: ignore[misc] pass def test_getitem(self) -> None: d1: AppKeyDict = {"a": 2, "b": 3} d2: AppKeyDict = {"a": 1} cp = helpers.ChainMapProxy([d1, d2]) assert cp["a"] == 2 assert cp["b"] == 3 def test_getitem_not_found(self) -> None: d: AppKeyDict = {"a": 1} cp = helpers.ChainMapProxy([d]) with pytest.raises(KeyError): cp["b"] def test_get(self) -> None: d1: AppKeyDict = {"a": 2, "b": 3} d2: AppKeyDict = {"a": 1} cp = helpers.ChainMapProxy([d1, d2]) assert cp.get("a") == 2 def test_get_default(self) -> None: d1: AppKeyDict = {"a": 2, "b": 3} d2: AppKeyDict = {"a": 1} cp = helpers.ChainMapProxy([d1, d2]) assert cp.get("c", 4) == 4 def test_get_non_default(self) -> None: d1: AppKeyDict = {"a": 2, "b": 3} d2: AppKeyDict = {"a": 1} cp = helpers.ChainMapProxy([d1, d2]) assert cp.get("a", 4) == 2 def test_len(self) -> None: d1: AppKeyDict = {"a": 2, "b": 3} d2: AppKeyDict = {"a": 1} cp = helpers.ChainMapProxy([d1, d2]) assert len(cp) == 2 def test_iter(self) -> None: d1: AppKeyDict = {"a": 2, "b": 3} d2: AppKeyDict = {"a": 1} cp = helpers.ChainMapProxy([d1, d2]) assert set(cp) == {"a", "b"} def test_contains(self) -> None: d1: AppKeyDict = {"a": 2, "b": 3} d2: AppKeyDict = {"a": 1} cp = helpers.ChainMapProxy([d1, d2]) assert "a" in cp assert "b" in cp assert "c" not in cp def test_bool(self) -> None: assert helpers.ChainMapProxy([{"a": 1}]) assert not helpers.ChainMapProxy([{}, {}]) assert not helpers.ChainMapProxy([]) def test_repr(self) -> None: d1: AppKeyDict = {"a": 2, "b": 3} d2: AppKeyDict = {"a": 1} cp = helpers.ChainMapProxy([d1, d2]) expected = f"ChainMapProxy({d1!r}, {d2!r})" assert expected == repr(cp) def test_is_expected_content_type_json_match_exact() -> None: expected_ct = "application/json" response_ct = "application/json" assert 
is_expected_content_type( response_content_type=response_ct, expected_content_type=expected_ct ) def test_is_expected_content_type_json_match_partially() -> None: expected_ct = "application/json" response_ct = "application/alto-costmap+json" # mime-type from rfc7285 assert is_expected_content_type( response_content_type=response_ct, expected_content_type=expected_ct ) def test_is_expected_content_type_non_application_json_suffix() -> None: expected_ct = "application/json" response_ct = "model/gltf+json" # rfc 6839 assert is_expected_content_type( response_content_type=response_ct, expected_content_type=expected_ct ) def test_is_expected_content_type_non_application_json_private_suffix() -> None: expected_ct = "application/json" response_ct = "x-foo/bar+json" # rfc 6839 assert is_expected_content_type( response_content_type=response_ct, expected_content_type=expected_ct ) def test_is_expected_content_type_json_non_lowercase() -> None: """Per RFC 2045, media type matching is case insensitive.""" expected_ct = "application/json" response_ct = "Application/JSON" assert is_expected_content_type( response_content_type=response_ct, expected_content_type=expected_ct ) def test_is_expected_content_type_json_trailing_chars() -> None: expected_ct = "application/json" response_ct = "application/json-seq" assert not is_expected_content_type( response_content_type=response_ct, expected_content_type=expected_ct ) def test_is_expected_content_type_non_json_match_exact() -> None: expected_ct = "text/javascript" response_ct = "text/javascript" assert is_expected_content_type( response_content_type=response_ct, expected_content_type=expected_ct ) def test_is_expected_content_type_non_json_not_match() -> None: expected_ct = "application/json" response_ct = "text/plain" assert not is_expected_content_type( response_content_type=response_ct, expected_content_type=expected_ct ) # It's necessary to subclass CookieMixin before using it. # See the comments on its __slots__. 
class CookieImplementation(helpers.CookieMixin):
    """Minimal concrete subclass of ``helpers.CookieMixin`` used by the cookie tests."""

    pass


def test_cookies_mixin() -> None:
    """Walk one cookie through set, overwrite, direct mutation, delete and re-set.

    After every step the rendered ``Set-Cookie`` text of ``sut.cookies`` is
    compared against the exact expected string.
    """
    sut = CookieImplementation()

    # A fresh mixin has no cookies and renders to an empty string.
    assert sut.cookies == {}
    assert str(sut.cookies) == ""

    sut.set_cookie("name", "value")
    assert str(sut.cookies) == "Set-Cookie: name=value; Path=/"

    # An empty value is rendered as a quoted empty string.
    sut.set_cookie("name", "")
    assert str(sut.cookies) == 'Set-Cookie: name=""; Path=/'

    sut.set_cookie("name", "value")
    assert str(sut.cookies) == "Set-Cookie: name=value; Path=/"

    sut.set_cookie("name", "other_value")
    assert str(sut.cookies) == "Set-Cookie: name=other_value; Path=/"

    # Mutate the underlying cookie container directly instead of via set_cookie.
    sut.cookies["name"] = "another_other_value"
    sut.cookies["name"]["max-age"] = 10
    assert (
        str(sut.cookies) == "Set-Cookie: name=another_other_value; Max-Age=10; Path=/"
    )

    sut.del_cookie("name")
    expected = (
        'Set-Cookie: name=""; '
        "expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; Path=/"
    )
    assert str(sut.cookies) == expected

    # Deleting a second time must produce the same output.
    sut.del_cookie("name")
    assert str(sut.cookies) == expected

    sut.set_cookie("name", "value", domain="local.host")
    expected = "Set-Cookie: name=value; Domain=local.host; Path=/"
    assert str(sut.cookies) == expected


def test_cookies_mixin_path() -> None:
    """Check rendering of ``path``, ``expires`` and a cookie with every attribute set."""
    sut = CookieImplementation()

    assert sut.cookies == {}

    sut.set_cookie("name", "value", path="/some/path")
    assert str(sut.cookies) == "Set-Cookie: name=value; Path=/some/path"
    sut.set_cookie("name", "value", expires="123")
    assert str(sut.cookies) == "Set-Cookie: name=value; expires=123; Path=/"
    sut.set_cookie(
        "name",
        "value",
        domain="example.com",
        path="/home",
        expires="123",
        max_age="10",
        secure=True,
        httponly=True,
        samesite="lax",
    )
    # The comparison is made on the lower-cased rendering, so attribute-name
    # casing is deliberately not part of this expectation.
    assert (
        str(sut.cookies).lower() == "set-cookie: name=value; "
        "domain=example.com; "
        "expires=123; "
        "httponly; "
        "max-age=10; "
        "path=/home; "
        "samesite=lax; "
        "secure"
    )


@pytest.mark.skipif(sys.version_info < (3, 14), reason="No partitioned support")
def test_cookies_mixin_partitioned() -> None:
    """Check that ``partitioned=True`` adds ``Partitioned`` and ``False`` does not."""
    sut = CookieImplementation()

    assert sut.cookies == {}

    sut.set_cookie("name", "value", partitioned=False)
    assert str(sut.cookies) == "Set-Cookie: name=value; Path=/"

    sut.set_cookie("name", "value", partitioned=True)
    assert str(sut.cookies) == "Set-Cookie: name=value; Partitioned; Path=/"


def test_sutonse_cookie__issue_del_cookie() -> None:
    """Deleting a cookie that was never set still renders an expiring ``Set-Cookie``."""
    sut = CookieImplementation()

    assert sut.cookies == {}
    assert str(sut.cookies) == ""

    sut.del_cookie("name")
    expected = (
        'Set-Cookie: name=""; '
        "expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; Path=/"
    )
    assert str(sut.cookies) == expected


def test_cookie_set_after_del() -> None:
    """Setting a cookie after deleting it renders a plain cookie again."""
    sut = CookieImplementation()

    sut.del_cookie("name")
    sut.set_cookie("name", "val")
    # check for Max-Age dropped
    expected = "Set-Cookie: name=val; Path=/"
    assert str(sut.cookies) == expected


def test_populate_with_cookies() -> None:
    """``populate_with_cookies`` adds the mixin's cookie as a ``Set-Cookie`` header."""
    cookies_mixin = CookieImplementation()
    cookies_mixin.set_cookie("name", "value")
    headers = CIMultiDict[str]()

    helpers.populate_with_cookies(headers, cookies_mixin.cookies)
    assert headers == CIMultiDict({"Set-Cookie": "name=value; Path=/"})


@pytest.mark.parametrize(
    ["value", "expected"],
    [
        # email.utils.parsedate returns None
        pytest.param("xxyyzz", None),
        # datetime.datetime fails with ValueError("year 4446413 is out of range")
        pytest.param("Tue, 08 Oct 4446413 00:56:40 GMT", None),
        # datetime.datetime fails with ValueError("second must be in 0..59")
        pytest.param("Tue, 08 Oct 2000 00:56:80 GMT", None),
        # OK
        pytest.param(
            "Tue, 08 Oct 2000 00:56:40 GMT",
            datetime.datetime(2000, 10, 8, 0, 56, 40, tzinfo=datetime.timezone.utc),
        ),
        # OK (ignore timezone and overwrite to UTC)
        pytest.param(
            "Tue, 08 Oct 2000 00:56:40 +0900",
            datetime.datetime(2000, 10, 8, 0, 56, 40, tzinfo=datetime.timezone.utc),
        ),
    ],
)
def test_parse_http_date(value: str, expected: datetime.datetime | None) -> None:
    """``parse_http_date`` yields a UTC datetime for valid input, else ``None``."""
    assert parse_http_date(value) == expected


@pytest.mark.parametrize(
    ["netrc_contents", "expected_username"],
    [
        (
            "machine example.com login username password pass\n",
            "username",
        ),
    ],
    indirect=("netrc_contents",),
)
@pytest.mark.usefixtures("netrc_contents")
def test_netrc_from_env(expected_username: str) -> None:
    """Test that reading netrc files from env works as expected"""
    netrc_obj = helpers.netrc_from_env()
    assert netrc_obj is not None
    auth = netrc_obj.authenticators("example.com")
    assert auth is not None
    # Index 0 of the authenticators tuple is compared with the expected login.
    assert auth[0] == expected_username


@pytest.fixture
def protected_dir(tmp_path: Path) -> Iterator[Path]:
    """Yield an empty directory under ``tmp_path`` whose mode is set to ``0o600``.

    The directory is removed again when the fixture is torn down, including
    when ``chmod`` itself fails.
    """
    protected_dir = tmp_path / "protected"
    protected_dir.mkdir()
    try:
        protected_dir.chmod(0o600)
        yield protected_dir
    finally:
        protected_dir.rmdir()


def test_netrc_from_home_does_not_raise_if_access_denied(
    protected_dir: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """``netrc_from_env`` must not raise when the home directory is restricted.

    ``Path.home`` is patched to return the restricted directory and ``NETRC``
    is removed from the environment; the call only has to complete.
    """
    monkeypatch.setattr(Path, "home", lambda: protected_dir)
    monkeypatch.delenv("NETRC", raising=False)

    helpers.netrc_from_env()


@pytest.mark.parametrize(
    ["netrc_contents", "expected_auth"],
    [
        (
            "machine example.com login username password pass\n",
            helpers.BasicAuth("username", "pass"),
        ),
        (
            "machine example.com account username password pass\n",
            helpers.BasicAuth("username", "pass"),
        ),
        (
            "machine example.com password pass\n",
            helpers.BasicAuth("", "pass"),
        ),
    ],
    indirect=("netrc_contents",),
)
@pytest.mark.usefixtures("netrc_contents")
def test_basicauth_present_in_netrc(  # type: ignore[misc]
    expected_auth: helpers.BasicAuth,
) -> None:
    """Test that netrc file contents are properly parsed into BasicAuth tuples"""
    netrc_obj = helpers.netrc_from_env()

    assert expected_auth == helpers.basicauth_from_netrc(netrc_obj, "example.com")


@pytest.mark.parametrize(
    ["netrc_contents"],
    [
        ("",),
    ],
    indirect=("netrc_contents",),
)
@pytest.mark.usefixtures("netrc_contents")
def test_read_basicauth_from_empty_netrc() -> None:
    """Test that an error is raised if netrc doesn't have an entry for our host"""
    netrc_obj = helpers.netrc_from_env()

    with pytest.raises(
        LookupError, match="No entry for example.com found in the `.netrc` file."
): helpers.basicauth_from_netrc(netrc_obj, "example.com") def test_method_must_be_empty_body() -> None: """Test that HEAD is the only method that unequivocally must have an empty body.""" assert "HEAD" in EMPTY_BODY_METHODS # CONNECT is only empty on a successful response assert "CONNECT" not in EMPTY_BODY_METHODS def test_should_remove_content_length_is_subset_of_must_be_empty_body() -> None: """Test should_remove_content_length is always a subset of must_be_empty_body.""" assert should_remove_content_length("GET", 101) is True assert must_be_empty_body("GET", 101) is True assert should_remove_content_length("GET", 102) is True assert must_be_empty_body("GET", 102) is True assert should_remove_content_length("GET", 204) is True assert must_be_empty_body("GET", 204) is True assert should_remove_content_length("GET", 204) is True assert must_be_empty_body("GET", 204) is True assert should_remove_content_length("GET", 200) is False assert must_be_empty_body("GET", 200) is False assert should_remove_content_length("HEAD", 200) is False assert must_be_empty_body("HEAD", 200) is True # CONNECT is only empty on a successful response assert should_remove_content_length("CONNECT", 200) is True assert must_be_empty_body("CONNECT", 200) is True assert should_remove_content_length("CONNECT", 201) is True assert must_be_empty_body("CONNECT", 201) is True assert should_remove_content_length("CONNECT", 300) is False assert must_be_empty_body("CONNECT", 300) is False ================================================ FILE: tests/test_http_exceptions.py ================================================ # Tests for http_exceptions.py import pickle from multidict import CIMultiDict from aiohttp import http_exceptions class TestHttpProcessingError: def test_ctor(self) -> None: err = http_exceptions.HttpProcessingError( code=500, message="Internal error", headers=CIMultiDict() ) assert err.code == 500 assert err.message == "Internal error" assert err.headers == CIMultiDict() def 
test_pickle(self) -> None: err = http_exceptions.HttpProcessingError( code=500, message="Internal error", headers=CIMultiDict() ) err.foo = "bar" # type: ignore[attr-defined] for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.code == 500 assert err2.message == "Internal error" assert err2.headers == CIMultiDict() assert err2.foo == "bar" def test_str(self) -> None: err = http_exceptions.HttpProcessingError( code=500, message="Internal error", headers=CIMultiDict() ) assert str(err) == "500, message:\n Internal error" def test_repr(self) -> None: err = http_exceptions.HttpProcessingError( code=500, message="Internal error", headers=CIMultiDict() ) assert repr(err) == ("") class TestBadHttpMessage: def test_ctor(self) -> None: err = http_exceptions.BadHttpMessage("Bad HTTP message", headers=CIMultiDict()) assert err.code == 400 assert err.message == "Bad HTTP message" assert err.headers == CIMultiDict() def test_pickle(self) -> None: err = http_exceptions.BadHttpMessage( message="Bad HTTP message", headers=CIMultiDict() ) err.foo = "bar" # type: ignore[attr-defined] for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.code == 400 assert err2.message == "Bad HTTP message" assert err2.headers == {} assert err2.foo == "bar" def test_str(self) -> None: err = http_exceptions.BadHttpMessage( message="Bad HTTP message", headers=CIMultiDict() ) assert str(err) == "400, message:\n Bad HTTP message" def test_repr(self) -> None: err = http_exceptions.BadHttpMessage( message="Bad HTTP message", headers=CIMultiDict() ) assert repr(err) == "" class TestLineTooLong: def test_ctor(self) -> None: err = http_exceptions.LineTooLong(b"spam", 10) assert err.code == 400 assert err.message == "Got more than 10 bytes when reading: b'spam'." 
assert err.headers is None def test_pickle(self) -> None: err = http_exceptions.LineTooLong(line=b"spam", limit=10) err.foo = "bar" # type: ignore[attr-defined] for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.code == 400 assert err2.message == ("Got more than 10 bytes when reading: b'spam'.") assert err2.headers is None assert err2.foo == "bar" def test_str(self) -> None: err = http_exceptions.LineTooLong(line=b"spam", limit=10) expected = "400, message:\n Got more than 10 bytes when reading: b'spam'." assert str(err) == expected def test_repr(self) -> None: err = http_exceptions.LineTooLong(line=b"spam", limit=10) assert repr(err) == ( '" ) class TestInvalidHeader: def test_ctor(self) -> None: err = http_exceptions.InvalidHeader("X-Spam") assert err.code == 400 assert err.message == "Invalid HTTP header: 'X-Spam'" assert err.headers is None def test_pickle(self) -> None: err = http_exceptions.InvalidHeader(hdr="X-Spam") err.foo = "bar" # type: ignore[attr-defined] for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.code == 400 assert err2.message == "Invalid HTTP header: 'X-Spam'" assert err2.headers is None assert err2.foo == "bar" def test_str(self) -> None: err = http_exceptions.InvalidHeader(hdr="X-Spam") assert str(err) == "400, message:\n Invalid HTTP header: 'X-Spam'" def test_repr(self) -> None: err = http_exceptions.InvalidHeader(hdr="X-Spam") expected = "" assert repr(err) == expected class TestBadStatusLine: def test_ctor(self) -> None: err = http_exceptions.BadStatusLine("Test") assert err.line == "Test" assert str(err) == "400, message:\n Bad status line 'Test'" def test_ctor2(self) -> None: err = http_exceptions.BadStatusLine("") assert err.line == "" assert str(err) == "400, message:\n Bad status line ''" def test_pickle(self) -> None: err = http_exceptions.BadStatusLine("Test") err.foo = "bar" # type: 
ignore[attr-defined] for proto in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(err, proto) err2 = pickle.loads(pickled) assert err2.line == "Test" assert err2.foo == "bar" ================================================ FILE: tests/test_http_parser.py ================================================ # Tests for aiohttp/http_parser.py import asyncio import re import sys import zlib from collections.abc import Iterable from contextlib import suppress from typing import Any from unittest import mock from urllib.parse import quote import pytest from multidict import CIMultiDict from yarl import URL import aiohttp from aiohttp import http_exceptions, streams from aiohttp.base_protocol import BaseProtocol from aiohttp.helpers import NO_EXTENSIONS from aiohttp.http_parser import ( DeflateBuffer, HeadersParser, HttpParser, HttpPayloadParser, HttpRequestParser, HttpRequestParserPy, HttpResponseParser, HttpResponseParserPy, ) from aiohttp.http_writer import HttpVersion try: try: import brotlicffi as brotli except ImportError: import brotli except ImportError: # pragma: no cover brotli = None try: if sys.version_info >= (3, 14): import compression.zstd as zstandard # noqa: I900 else: import backports.zstd as zstandard except ImportError: zstandard = None # type: ignore[assignment] REQUEST_PARSERS = [HttpRequestParserPy] RESPONSE_PARSERS = [HttpResponseParserPy] with suppress(ImportError): from aiohttp.http_parser import HttpRequestParserC, HttpResponseParserC REQUEST_PARSERS.append(HttpRequestParserC) RESPONSE_PARSERS.append(HttpResponseParserC) @pytest.fixture def protocol() -> Any: return mock.create_autospec(BaseProtocol, spec_set=True, instance=True) def _gen_ids(parsers: Iterable[type[HttpParser[Any]]]) -> list[str]: return [ "py-parser" if parser.__module__ == "aiohttp.http_parser" else "c-parser" for parser in parsers ] @pytest.fixture(params=REQUEST_PARSERS, ids=_gen_ids(REQUEST_PARSERS)) def parser( loop: asyncio.AbstractEventLoop, protocol: BaseProtocol,
request: pytest.FixtureRequest, ) -> HttpRequestParser: # Parser implementations return request.param( # type: ignore[no-any-return] protocol, loop, 2**16, max_line_size=8190, max_headers=128, max_field_size=8190, ) @pytest.fixture(params=REQUEST_PARSERS, ids=_gen_ids(REQUEST_PARSERS)) def request_cls(request: pytest.FixtureRequest) -> type[HttpRequestParser]: # Request Parser class return request.param # type: ignore[no-any-return] @pytest.fixture(params=RESPONSE_PARSERS, ids=_gen_ids(RESPONSE_PARSERS)) def response( loop: asyncio.AbstractEventLoop, protocol: BaseProtocol, request: pytest.FixtureRequest, ) -> HttpResponseParser: # Parser implementations return request.param( # type: ignore[no-any-return] protocol, loop, 2**16, max_line_size=8190, max_headers=128, max_field_size=8190, read_until_eof=True, ) @pytest.fixture(params=RESPONSE_PARSERS, ids=_gen_ids(RESPONSE_PARSERS)) def response_cls(request: pytest.FixtureRequest) -> type[HttpResponseParser]: # Parser implementations return request.param # type: ignore[no-any-return] @pytest.mark.skipif(NO_EXTENSIONS, reason="Extensions available but not imported") def test_c_parser_loaded() -> None: assert "HttpRequestParserC" in dir(aiohttp.http_parser) assert "HttpResponseParserC" in dir(aiohttp.http_parser) assert "RawRequestMessageC" in dir(aiohttp.http_parser) assert "RawResponseMessageC" in dir(aiohttp.http_parser) def test_parse_headers(parser: HttpRequestParser) -> None: text = b"""GET /test HTTP/1.1\r test: a line\r test2: data\r \r """ messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 msg = messages[0][0] assert list(msg.headers.items()) == [("test", "a line"), ("test2", "data")] assert msg.raw_headers == ((b"test", b"a line"), (b"test2", b"data")) assert not msg.should_close assert msg.compression is None assert not msg.upgrade def test_reject_obsolete_line_folding(parser: HttpRequestParser) -> None: text = b"""GET /test HTTP/1.1\r test: line\r Content-Length: 48\r test2: data\r \r 
""" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) @pytest.mark.skipif(NO_EXTENSIONS, reason="Only tests C parser.") def test_invalid_character( loop: asyncio.AbstractEventLoop, protocol: BaseProtocol, request: pytest.FixtureRequest, ) -> None: parser = HttpRequestParserC( protocol, loop, 2**16, max_line_size=8190, max_field_size=8190, ) text = b"POST / HTTP/1.1\r\nHost: localhost:8080\r\nSet-Cookie: abc\x01def\r\n\r\n" error_detail = re.escape(r""": b'Set-Cookie: abc\x01def' ^""") with pytest.raises(http_exceptions.BadHttpMessage, match=error_detail): parser.feed_data(text) @pytest.mark.skipif(NO_EXTENSIONS, reason="Only tests C parser.") def test_invalid_linebreak( loop: asyncio.AbstractEventLoop, protocol: BaseProtocol, request: pytest.FixtureRequest, ) -> None: parser = HttpRequestParserC( protocol, loop, 2**16, max_line_size=8190, max_field_size=8190, ) text = b"GET /world HTTP/1.1\r\nHost: 127.0.0.1\n\r\n" error_detail = re.escape(r""": b'Host: 127.0.0.1\n' ^""") with pytest.raises(http_exceptions.BadHttpMessage, match=error_detail): parser.feed_data(text) def test_cve_2023_37276(parser: HttpRequestParser) -> None: text = b"""POST / HTTP/1.1\r\nHost: localhost:8080\r\nX-Abc: \rxTransfer-Encoding: chunked\r\n\r\n""" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) @pytest.mark.parametrize( "rfc9110_5_6_2_token_delim", r'"(),/:;<=>?@[\]{}', ) def test_bad_header_name( parser: HttpRequestParser, rfc9110_5_6_2_token_delim: str ) -> None: text = f"POST / HTTP/1.1\r\nhead{rfc9110_5_6_2_token_delim}er: val\r\n\r\n".encode() if rfc9110_5_6_2_token_delim == ":": # Inserting colon into header just splits name/value earlier. 
parser.feed_data(text) return with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) @pytest.mark.parametrize( "hdr", ( "Content-Length: -5", # https://www.rfc-editor.org/rfc/rfc9110.html#name-content-length "Content-Length: +256", "Content-Length: \N{SUPERSCRIPT ONE}", "Content-Length: \N{MATHEMATICAL DOUBLE-STRUCK DIGIT ONE}", "Foo: abc\rdef", # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-5 "Bar: abc\ndef", "Baz: abc\x00def", "Foo : bar", # https://www.rfc-editor.org/rfc/rfc9112.html#section-5.1-2 "Foo\t: bar", "\xffoo: bar", "Foo: abc\x01def", # CTL bytes forbidden per RFC 9110 §5.5 "Foo: abc\x7fdef", # DEL is also a CTL byte "Foo: abc\x1fdef", ), ) def test_bad_headers(parser: HttpRequestParser, hdr: str) -> None: text = f"POST / HTTP/1.1\r\n{hdr}\r\n\r\n".encode() with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) def test_ctl_host_header_bad_characters(parser: HttpRequestParser) -> None: """CTL byte in Host header must be rejected.""" text = b"GET /test HTTP/1.1\r\nHost: trusted.example\x01@bad.test\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) def test_unpaired_surrogate_in_header_py( loop: asyncio.AbstractEventLoop, protocol: BaseProtocol ) -> None: parser = HttpRequestParserPy( protocol, loop, 2**16, max_line_size=8190, max_field_size=8190, ) text = b"POST / HTTP/1.1\r\n\xff\r\n\r\n" message = None try: parser.feed_data(text) except http_exceptions.InvalidHeader as e: message = e.message.encode("utf-8") assert message is not None def test_content_length_transfer_encoding(parser: HttpRequestParser) -> None: text = ( b"GET / HTTP/1.1\r\nHost: a\r\nContent-Length: 5\r\nTransfer-Encoding: a\r\n\r\n" + b"apple\r\n" ) with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) @pytest.mark.parametrize( "hdr", ( "Content-Length", "Content-Location", "Content-Range", "Content-Type", "ETag", "Host", "Max-Forwards", "Server", "Transfer-Encoding", 
"User-Agent", ), ) def test_duplicate_singleton_header_rejected( parser: HttpRequestParser, hdr: str ) -> None: val1, val2 = ("1", "2") if hdr == "Content-Length" else ("value1", "value2") text = ( f"GET /test HTTP/1.1\r\n" f"Host: example.com\r\n" f"{hdr}: {val1}\r\n" f"{hdr}: {val2}\r\n" f"\r\n" ).encode() with pytest.raises(http_exceptions.BadHttpMessage, match="Duplicate"): parser.feed_data(text) def test_duplicate_host_header_rejected(parser: HttpRequestParser) -> None: text = ( b"GET /admin HTTP/1.1\r\n" b"Host: admin.example\r\n" b"Host: public.example\r\n" b"\r\n" ) with pytest.raises(http_exceptions.BadHttpMessage, match="Duplicate.*Host"): parser.feed_data(text) def test_bad_chunked(parser: HttpRequestParser) -> None: """Test that invalid chunked encoding doesn't allow content-length to be used.""" text = ( b"GET / HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunked\r\n\r\n0_2e\r\n\r\n" + b"GET / HTTP/1.1\r\nHost: a\r\nContent-Length: 5\r\n\r\n0\r\n\r\n" ) with pytest.raises(http_exceptions.BadHttpMessage, match="0_2e"): parser.feed_data(text) def test_whitespace_before_header(parser: HttpRequestParser) -> None: text = b"GET / HTTP/1.1\r\n\tContent-Length: 1\r\n\r\nX" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) @pytest.fixture def xfail_c_parser_status(request: pytest.FixtureRequest) -> None: if isinstance(request.getfixturevalue("parser"), HttpRequestParserPy): return request.node.add_marker( pytest.mark.xfail( reason="Regression test for Py parser. 
May match C behaviour later.", raises=http_exceptions.BadStatusLine, ) ) @pytest.mark.usefixtures("xfail_c_parser_status") def test_parse_unusual_request_line(parser: HttpRequestParser) -> None: text = b"#smol //a HTTP/1.3\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 msg, _ = messages[0] assert msg.compression is None assert not msg.upgrade assert msg.method == "#smol" assert msg.path == "//a" assert msg.version == (1, 3) def test_parse(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 msg, _ = messages[0] assert msg.compression is None assert not msg.upgrade assert msg.method == "GET" assert msg.path == "/test" assert msg.version == (1, 1) async def test_parse_body(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 _, payload = messages[0] body = await payload.read(4) assert body == b"body" async def test_parse_body_with_CRLF(parser: HttpRequestParser) -> None: text = b"\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 _, payload = messages[0] body = await payload.read(4) assert body == b"body" def test_parse_delayed(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\n" messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 0 assert not upgrade messages, upgrade, tail = parser.feed_data(b"\r\n") assert len(messages) == 1 msg = messages[0][0] assert msg.method == "GET" def test_headers_multi_feed(parser: HttpRequestParser) -> None: text1 = b"GET /test HTTP/1.1\r\n" text2 = b"test: line" text3 = b" continue\r\n\r\n" messages, upgrade, tail = parser.feed_data(text1) assert len(messages) == 0 messages, upgrade, tail = parser.feed_data(text2) assert len(messages) == 0 messages, upgrade, tail = 
parser.feed_data(text3) assert len(messages) == 1 msg = messages[0][0] assert list(msg.headers.items()) == [("test", "line continue")] assert msg.raw_headers == ((b"test", b"line continue"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade def test_headers_split_field(parser: HttpRequestParser) -> None: text1 = b"GET /test HTTP/1.1\r\n" text2 = b"t" text3 = b"es" text4 = b"t: value\r\n\r\n" messages, upgrade, tail = parser.feed_data(text1) messages, upgrade, tail = parser.feed_data(text2) messages, upgrade, tail = parser.feed_data(text3) assert len(messages) == 0 messages, upgrade, tail = parser.feed_data(text4) assert len(messages) == 1 msg = messages[0][0] assert list(msg.headers.items()) == [("test", "value")] assert msg.raw_headers == ((b"test", b"value"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade def test_parse_headers_multi(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" b"Set-Cookie: c1=cookie1\r\n" b"Set-Cookie: c2=cookie2\r\n\r\n" ) messages, upgrade, tail = parser.feed_data(text) assert len(messages) == 1 msg = messages[0][0] assert list(msg.headers.items()) == [ ("Set-Cookie", "c1=cookie1"), ("Set-Cookie", "c2=cookie2"), ] assert msg.raw_headers == ( (b"Set-Cookie", b"c1=cookie1"), (b"Set-Cookie", b"c2=cookie2"), ) assert not msg.should_close assert msg.compression is None def test_conn_default_1_0(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.0\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.should_close def test_conn_default_1_1(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close def test_conn_close(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\nconnection: close\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert 
msg.should_close def test_conn_close_1_0(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.0\r\nconnection: close\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.should_close def test_conn_keep_alive_1_0(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.0\r\nconnection: keep-alive\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close def test_conn_keep_alive_1_1(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\nconnection: keep-alive\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close def test_conn_close_comma_list(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\nconnection: close, keep-alive\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.should_close def test_conn_close_multiple_headers(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" b"connection: keep-alive\r\n" b"connection: close\r\n\r\n" ) messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.should_close def test_conn_other_1_0(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.0\r\nconnection: test\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.should_close def test_conn_other_1_1(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\nconnection: test\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close def test_request_chunked(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg, payload = messages[0] assert msg.chunked assert not upgrade assert isinstance(payload, streams.StreamReader) def test_te_header_non_ascii(parser: HttpRequestParser) -> None: # K = Kelvin sign, not valid ascii. 
text = "GET /test HTTP/1.1\r\nTransfer-Encoding: chunKed\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text.encode()) def test_upgrade_header_non_ascii(parser: HttpRequestParser) -> None: # K = Kelvin sign, not valid ascii. text = "GET /test HTTP/1.1\r\nUpgrade: websocKet\r\n\r\n" messages, upgrade, tail = parser.feed_data(text.encode()) assert not upgrade def test_request_te_chunked_with_content_length(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" b"content-length: 1234\r\n" b"transfer-encoding: chunked\r\n\r\n" ) with pytest.raises( http_exceptions.BadHttpMessage, match="Transfer-Encoding can't be present with Content-Length", ): parser.feed_data(text) def test_request_te_chunked123(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked123\r\n\r\n" with pytest.raises( http_exceptions.BadHttpMessage, match="Request has invalid `Transfer-Encoding`", ): parser.feed_data(text) async def test_request_te_last_chunked(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: not, chunked\r\n\r\n1\r\nT\r\n3\r\nest\r\n0\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.3 assert await messages[0][1].read() == b"Test" def test_request_te_first_chunked(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked, not\r\n\r\n1\r\nT\r\n3\r\nest\r\n0\r\n\r\n" # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.3 with pytest.raises( http_exceptions.BadHttpMessage, match="nvalid `Transfer-Encoding`", ): parser.feed_data(text) def test_request_te_duplicate_chunked(parser: HttpRequestParser) -> None: """Reject duplicate chunked Transfer-Encoding per RFC 9112 section 7.1.""" text = b"POST / HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunked, chunked\r\n\r\n0\r\n\r\n" # https://www.rfc-editor.org/rfc/rfc9112#section-7.1-3 with pytest.raises( 
http_exceptions.BadHttpMessage, match="duplicate `chunked` Transfer-Encoding|nvalid `Transfer-Encoding`", ): parser.feed_data(text) def test_conn_upgrade(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" b"connection: upgrade\r\n" b"upgrade: websocket\r\n\r\n" ) messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close assert msg.upgrade assert upgrade def test_conn_upgrade_comma_list(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" b"connection: keep-alive, upgrade\r\n" b"upgrade: websocket\r\n\r\n" ) messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close assert msg.upgrade assert upgrade def test_conn_upgrade_multiple_headers(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" b"connection: keep-alive\r\n" b"connection: upgrade\r\n" b"upgrade: websocket\r\n\r\n" ) messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close assert msg.upgrade assert upgrade def test_bad_upgrade(parser: HttpRequestParser) -> None: """Test not upgraded if missing Upgrade header.""" text = b"GET /test HTTP/1.1\r\nconnection: upgrade\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.upgrade assert not upgrade def test_compression_empty(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ncontent-encoding: \r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.compression is None def test_compression_deflate(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ncontent-encoding: deflate\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.compression == "deflate" def test_compression_gzip(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ncontent-encoding: gzip\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert 
msg.compression == "gzip" @pytest.mark.skipif(brotli is None, reason="brotli is not installed") def test_compression_brotli(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ncontent-encoding: br\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.compression == "br" @pytest.mark.skipif(zstandard is None, reason="zstandard is not installed") def test_compression_zstd(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ncontent-encoding: zstd\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.compression == "zstd" @pytest.mark.parametrize( "enc", ( "zstd".encode(), # "st".upper() == "ST" "deflate".encode(), # "fl".upper() == "FL" ), ) def test_compression_non_ascii(parser: HttpRequestParser, enc: bytes) -> None: text = b"GET /test HTTP/1.1\r\ncontent-encoding: " + enc + b"\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] # Non-ascii input should not evaluate to a valid encoding scheme. 
assert msg.compression is None def test_compression_unknown(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ncontent-encoding: compress\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.compression is None def test_url_connect(parser: HttpRequestParser) -> None: text = b"CONNECT www.google.com HTTP/1.1\r\ncontent-length: 0\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg, payload = messages[0] assert upgrade assert msg.url == URL.build(authority="www.google.com") def test_headers_connect(parser: HttpRequestParser) -> None: text = b"CONNECT www.google.com HTTP/1.1\r\ncontent-length: 0\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg, payload = messages[0] assert upgrade assert isinstance(payload, streams.StreamReader) def test_url_absolute(parser: HttpRequestParser) -> None: text = ( b"GET https://www.google.com/path/to.html HTTP/1.1\r\n" b"content-length: 0\r\n\r\n" ) messages, upgrade, tail = parser.feed_data(text) msg, payload = messages[0] assert not upgrade assert msg.method == "GET" assert msg.url == URL("https://www.google.com/path/to.html") def test_headers_old_websocket_key1(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\nSEC-WEBSOCKET-KEY1: line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) def test_headers_content_length_err_1(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ncontent-length: line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) def test_headers_content_length_err_2(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ncontent-length: -1\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) _pad: dict[bytes, str] = { b"": "empty", # not a typo. 
Python likes triple zero b"\000": "NUL", b" ": "SP", b" ": "SPSP", # not a typo: both 0xa0 and 0x0a in case of 8-bit fun b"\n": "LF", b"\xa0": "NBSP", b"\t ": "TABSP", } @pytest.mark.parametrize("hdr", [b"", b"foo"], ids=["name-empty", "with-name"]) @pytest.mark.parametrize("pad2", _pad.keys(), ids=["post-" + n for n in _pad.values()]) @pytest.mark.parametrize("pad1", _pad.keys(), ids=["pre-" + n for n in _pad.values()]) def test_invalid_header_spacing( parser: HttpRequestParser, pad1: bytes, pad2: bytes, hdr: bytes ) -> None: text = b"GET /test HTTP/1.1\r\n%s%s%s: value\r\n\r\n" % (pad1, hdr, pad2) if pad1 == pad2 == b"" and hdr != b"": # one entry in param matrix is correct: non-empty name, not padded parser.feed_data(text) return with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) def test_empty_header_name(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\n:test\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) def test_invalid_header(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ntest line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) def test_invalid_name(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ntest[]: line\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage): parser.feed_data(text) @pytest.mark.parametrize("size", [40960, 8191]) def test_max_header_field_size(parser: HttpRequestParser, size: int) -> None: name = b"t" * size text = b"GET /test HTTP/1.1\r\n" + name + b":data\r\n\r\n" match = "400, message:\n Got more than 8190 bytes when reading" with pytest.raises(http_exceptions.LineTooLong, match=match): for i in range(0, len(text), 5000): # pragma: no branch parser.feed_data(text[i : i + 5000]) def test_max_header_size_under_limit(parser: HttpRequestParser) -> None: name = b"t" * 8185 text = b"GET /test HTTP/1.1\r\n" + name + b":data\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg 
= messages[0][0] assert msg.method == "GET" assert msg.path == "/test" assert msg.version == (1, 1) assert msg.headers == CIMultiDict({name.decode(): "data"}) assert msg.raw_headers == ((name, b"data"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade assert not msg.chunked assert msg.url == URL("/test") @pytest.mark.parametrize("size", [40960, 8191]) def test_max_header_value_size(parser: HttpRequestParser, size: int) -> None: name = b"t" * size text = b"GET /test HTTP/1.1\r\ndata:" + name + b"\r\n\r\n" match = "400, message:\n Got more than 8190 bytes when reading" with pytest.raises(http_exceptions.LineTooLong, match=match): for i in range(0, len(text), 4000): # pragma: no branch parser.feed_data(text[i : i + 4000]) def test_max_header_combined_size(parser: HttpRequestParser) -> None: k = b"t" * 4100 text = b"GET /test HTTP/1.1\r\n" + k + b":" + k + b"\r\n\r\n" match = "400, message:\n Got more than 8190 bytes when reading" with pytest.raises(http_exceptions.LineTooLong, match=match): parser.feed_data(text) @pytest.mark.parametrize("size", [40960, 8191]) async def test_max_trailer_size(parser: HttpRequestParser, size: int) -> None: value = b"t" * size text = ( b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" + hex(4000)[2:].encode() + b"\r\n" + b"b" * 4000 + b"\r\n0\r\ntest: " + value + b"\r\n\r\n" ) match = "400, message:\n Got more than 8190 bytes when reading" with pytest.raises(http_exceptions.LineTooLong, match=match): payload = None for i in range(0, len(text), 3000): # pragma: no branch messages, upgrade, tail = parser.feed_data(text[i : i + 3000]) if messages: payload = messages[0][-1] # Trailers are not seen until payload is read. 
assert payload is not None await payload.read() @pytest.mark.parametrize("headers,trailers", ((129, 0), (0, 129), (64, 65))) async def test_max_headers( parser: HttpRequestParser, headers: int, trailers: int ) -> None: text = ( b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked" + b"".join(b"\r\nHeader-%d: Value" % i for i in range(headers)) + b"\r\n\r\n4\r\ntest\r\n0" + b"".join(b"\r\nTrailer-%d: Value" % i for i in range(trailers)) + b"\r\n\r\n" ) match = "Too many (headers|trailers) received" with pytest.raises(http_exceptions.BadHttpMessage, match=match): messages, upgrade, tail = parser.feed_data(text) # Trailers are not seen until payload is read. await messages[0][-1].read() def test_max_header_value_size_under_limit(parser: HttpRequestParser) -> None: value = b"A" * 8185 text = b"GET /test HTTP/1.1\r\ndata:" + value + b"\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.method == "GET" assert msg.path == "/test" assert msg.version == (1, 1) assert msg.headers == CIMultiDict({"data": value.decode()}) assert msg.raw_headers == ((b"data", value),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade assert not msg.chunked assert msg.url == URL("/test") @pytest.mark.parametrize("size", [40965, 8191]) def test_max_header_value_size_continuation( response: HttpResponseParser, size: int ) -> None: name = b"T" * (size - 5) text = b"HTTP/1.1 200 Ok\r\ndata: test\r\n " + name + b"\r\n\r\n" match = "400, message:\n Got more than 8190 bytes when reading" with pytest.raises(http_exceptions.LineTooLong, match=match): for i in range(0, len(text), 9000): # pragma: no branch response.feed_data(text[i : i + 9000]) def test_max_header_value_size_continuation_under_limit( response: HttpResponseParser, ) -> None: value = b"A" * 8179 text = b"HTTP/1.1 200 Ok\r\ndata: test\r\n " + value + b"\r\n\r\n" messages, upgrade, tail = response.feed_data(text) msg = messages[0][0] assert msg.code == 200 assert msg.reason 
== "Ok" assert msg.version == (1, 1) assert msg.headers == CIMultiDict({"data": "test " + value.decode()}) assert msg.raw_headers == ((b"data", b"test " + value),) assert msg.should_close assert msg.compression is None assert not msg.upgrade assert not msg.chunked def test_http_request_parser(parser: HttpRequestParser) -> None: text = b"GET /path HTTP/1.1\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.method == "GET" assert msg.path == "/path" assert msg.version == (1, 1) assert msg.headers == CIMultiDict() assert msg.raw_headers == () assert not msg.should_close assert msg.compression is None assert not msg.upgrade assert not msg.chunked assert msg.url == URL("/path") def test_http_request_bad_status_line(parser: HttpRequestParser) -> None: text = b"getpath \r\n\r\n" with pytest.raises(http_exceptions.BadStatusLine) as exc_info: parser.feed_data(text) # Check for accidentally escaped message. assert r"\n" not in exc_info.value.message _num: dict[bytes, str] = { # dangerous: accepted by Python int() # unicodedata.category("\U0001D7D9") == 'Nd' "\N{MATHEMATICAL DOUBLE-STRUCK DIGIT ONE}".encode(): "utf8digit", # only added for interop tests, refused by Python int() # unicodedata.category("\U000000B9") == 'No' "\N{SUPERSCRIPT ONE}".encode(): "utf8number", "\N{SUPERSCRIPT ONE}".encode("latin-1"): "latin1number", } @pytest.mark.parametrize("nonascii_digit", _num.keys(), ids=_num.values()) def test_http_request_bad_status_line_number( parser: HttpRequestParser, nonascii_digit: bytes ) -> None: text = b"GET /digit HTTP/1." 
+ nonascii_digit + b"\r\n\r\n" with pytest.raises(http_exceptions.BadStatusLine): parser.feed_data(text) def test_http_request_bad_status_line_separator(parser: HttpRequestParser) -> None: # single code point, old, multibyte NFKC, multibyte NFKD utf8sep = "\N{ARABIC LIGATURE SALLALLAHOU ALAYHE WASALLAM}".encode() text = b"GET /ligature HTTP/1" + utf8sep + b"1\r\n\r\n" with pytest.raises(http_exceptions.BadStatusLine): parser.feed_data(text) def test_http_request_bad_status_line_whitespace(parser: HttpRequestParser) -> None: text = b"GET\n/path\fHTTP/1.1\r\n\r\n" with pytest.raises(http_exceptions.BadStatusLine): parser.feed_data(text) def test_http_request_message_after_close(parser: HttpRequestParser) -> None: text = b"GET / HTTP/1.1\r\nConnection: close\r\n\r\nInvalid\r\n\r\n" with pytest.raises( http_exceptions.BadHttpMessage, match="Data after `Connection: close`" ): parser.feed_data(text) def test_http_request_message_after_close_comma_list(parser: HttpRequestParser) -> None: text = b"GET / HTTP/1.1\r\nConnection: close, keep-alive\r\n\r\nInvalid\r\n\r\n" with pytest.raises( http_exceptions.BadHttpMessage, match="Data after `Connection: close`" ): parser.feed_data(text) def test_http_request_upgrade(parser: HttpRequestParser) -> None: text = ( b"GET /test HTTP/1.1\r\n" b"connection: upgrade\r\n" b"upgrade: websocket\r\n\r\n" b"some raw data" ) messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close assert msg.upgrade assert upgrade assert tail == b"some raw data" async def test_http_request_upgrade_unknown(parser: HttpRequestParser) -> None: text = ( b"POST / HTTP/1.1\r\n" b"Connection: Upgrade\r\n" b"Content-Length: 2\r\n" b"Upgrade: unknown\r\n" b"Content-Type: application/json\r\n\r\n" b"{}" ) messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert not msg.should_close assert msg.upgrade assert not upgrade assert not msg.chunked assert tail == b"" assert await messages[0][-1].read() == b"{}" 
@pytest.fixture
def xfail_c_parser_url(request: pytest.FixtureRequest) -> None:
    """Mark the requesting test as xfail unless `parser` is the Python parser.

    The xfail only applies when the test raises ``InvalidURLError``.
    """
    if isinstance(request.getfixturevalue("parser"), HttpRequestParserPy):
        return
    request.node.add_marker(
        pytest.mark.xfail(
            reason="Regression test for Py parser. May match C behaviour later.",
            raises=http_exceptions.InvalidURLError,
        )
    )


@pytest.mark.usefixtures("xfail_c_parser_url")
def test_http_request_parser_utf8_request_line(parser: HttpRequestParser) -> None:
    """Non-ASCII and truncated UTF-8 in the request line are surrogate-escaped."""
    messages, upgrade, tail = parser.feed_data(
        # note the truncated unicode sequence
        b"GET /P\xc3\xbcnktchen\xa0\xef\xb7 HTTP/1.1\r\n" +
        # for easier grep: ASCII 0xA0 more commonly known as non-breaking space
        # note the leading and trailing spaces
        "sTeP: \N{LATIN SMALL LETTER SHARP S}nek\t\N{NO-BREAK SPACE} "
        "\r\n\r\n".encode()
    )
    msg = messages[0][0]

    assert msg.method == "GET"
    assert msg.path == "/Pünktchen\udca0\udcef\udcb7"
    assert msg.version == (1, 1)
    assert msg.headers == CIMultiDict([("STEP", "ßnek\t\xa0")])
    assert msg.raw_headers == ((b"sTeP", "ßnek\t\xa0".encode()),)
    assert not msg.should_close
    assert msg.compression is None
    assert not msg.upgrade
    assert not msg.chunked
    # python HTTP parser depends on Cython and CPython URL to match
    # .. but yarl.URL("/abs") is not equal to URL.build(path="/abs"), see #6409
    assert msg.url == URL.build(path="/Pünktchen\udca0\udcef\udcb7", encoded=True)


def test_http_request_parser_utf8(parser: HttpRequestParser) -> None:
    """A UTF-8 header value is decoded in `headers` and kept raw in `raw_headers`."""
    text = "GET /path HTTP/1.1\r\nx-test:тест\r\n\r\n".encode()
    messages, upgrade, tail = parser.feed_data(text)
    msg = messages[0][0]

    assert msg.method == "GET"
    assert msg.path == "/path"
    assert msg.version == (1, 1)
    assert msg.headers == CIMultiDict([("X-TEST", "тест")])
    assert msg.raw_headers == ((b"x-test", "тест".encode()),)
    assert not msg.should_close
    assert msg.compression is None
    assert not msg.upgrade
    assert not msg.chunked
    assert msg.url == URL("/path")


def test_http_request_parser_non_utf8(parser: HttpRequestParser) -> None:
    """A cp1251 header value is decoded as UTF-8 with surrogateescape."""
    text = "GET /path HTTP/1.1\r\nx-test:тест\r\n\r\n".encode("cp1251")
    msg = parser.feed_data(text)[0][0][0]

    assert msg.method == "GET"
    assert msg.path == "/path"
    assert msg.version == (1, 1)
    assert msg.headers == CIMultiDict(
        [("X-TEST", "тест".encode("cp1251").decode("utf8", "surrogateescape"))]
    )
    assert msg.raw_headers == ((b"x-test", "тест".encode("cp1251")),)
    assert not msg.should_close
    assert msg.compression is None
    assert not msg.upgrade
    assert not msg.chunked
    assert msg.url == URL("/path")


def test_http_request_parser_two_slashes(parser: HttpRequestParser) -> None:
    """A path starting with `//` is kept as a path in both `path` and `url`."""
    text = b"GET //path HTTP/1.1\r\n\r\n"
    msg = parser.feed_data(text)[0][0][0]

    assert msg.method == "GET"
    assert msg.path == "//path"
    assert msg.url.path == "//path"
    assert msg.version == (1, 1)
    assert not msg.should_close
    assert msg.compression is None
    assert not msg.upgrade
    assert not msg.chunked


@pytest.mark.parametrize(
    "rfc9110_5_6_2_token_delim",
    [bytes([i]) for i in rb'"(),/:;<=>?@[\]{}'],
)
def test_http_request_parser_bad_method(
    parser: HttpRequestParser, rfc9110_5_6_2_token_delim: bytes
) -> None:
    """A method starting with a token delimiter raises BadHttpMethod."""
    with pytest.raises(http_exceptions.BadHttpMethod):
        parser.feed_data(rfc9110_5_6_2_token_delim + b'ET" /get HTTP/1.1\r\n\r\n')


def test_http_request_parser_bad_version(parser: HttpRequestParser) -> None:
    with pytest.raises(http_exceptions.BadHttpMessage):
        parser.feed_data(b"GET //get HT/11\r\n\r\n")


def test_http_request_parser_bad_version_number(parser: HttpRequestParser) -> None:
    with pytest.raises(http_exceptions.BadHttpMessage):
        parser.feed_data(b"GET /test HTTP/1.32\r\n\r\n")


def test_http_request_parser_bad_ascii_uri(parser: HttpRequestParser) -> None:
    with pytest.raises(http_exceptions.InvalidURLError):
        parser.feed_data(b"GET ! HTTP/1.1\r\n\r\n")


def test_http_request_parser_bad_nonascii_uri(parser: HttpRequestParser) -> None:
    with pytest.raises(http_exceptions.InvalidURLError):
        parser.feed_data(b"GET \xff HTTP/1.1\r\n\r\n")


@pytest.mark.parametrize("size", [40965, 8191])
def test_http_request_max_status_line(parser: HttpRequestParser, size: int) -> None:
    """An over-long request line raises LineTooLong."""
    path = b"t" * (size - 5)
    # " +" accepts any indent width before the message text.
    match = "400, message:\n +Got more than 8190 bytes when reading"
    with pytest.raises(http_exceptions.LineTooLong, match=match):
        parser.feed_data(b"GET /path" + path + b" HTTP/1.1\r\n\r\n")


def test_http_request_max_status_line_under_limit(parser: HttpRequestParser) -> None:
    """A request line whose path has 8172 extra bytes is accepted."""
    path = b"t" * 8172
    messages, upgraded, tail = parser.feed_data(
        b"GET /path" + path + b" HTTP/1.1\r\n\r\n"
    )
    msg = messages[0][0]

    assert msg.method == "GET"
    assert msg.path == "/path" + path.decode()
    assert msg.version == (1, 1)
    assert msg.headers == CIMultiDict()
    assert msg.raw_headers == ()
    assert not msg.should_close
    assert msg.compression is None
    assert not msg.upgrade
    assert not msg.chunked
    assert msg.url == URL("/path" + path.decode())


def test_http_response_parser_utf8(response: HttpResponseParser) -> None:
    text = "HTTP/1.1 200 Ok\r\nx-test:тест\r\n\r\n".encode()

    messages, upgraded, tail = response.feed_data(text)
    assert len(messages) == 1
    msg = messages[0][0]

    assert msg.version == (1, 1)
    assert msg.code == 200
    assert msg.reason == "Ok"
    assert msg.headers == CIMultiDict([("X-TEST", "тест")])
    assert msg.raw_headers == ((b"x-test", "тест".encode()),)
    assert not upgraded
    assert not tail


def test_http_response_parser_utf8_without_reason(response: HttpResponseParser) -> None:
    """A status line ending after the code and a space gives an empty reason."""
    text = "HTTP/1.1 200 \r\nx-test:тест\r\n\r\n".encode()

    messages, upgraded, tail = response.feed_data(text)
    assert len(messages) == 1
    msg = messages[0][0]

    assert msg.version == (1, 1)
    assert msg.code == 200
    assert msg.reason == ""
    assert msg.headers == CIMultiDict([("X-TEST", "тест")])
    assert msg.raw_headers == ((b"x-test", "тест".encode()),)
    assert not upgraded
    assert not tail


def test_http_response_parser_obs_line_folding(response: HttpResponseParser) -> None:
    """A folded header line is joined to the preceding value."""
    text = b"HTTP/1.1 200 Ok\r\ntest: line\r\n continue\r\n\r\n"

    messages, upgraded, tail = response.feed_data(text)
    assert len(messages) == 1
    msg = messages[0][0]

    assert msg.version == (1, 1)
    assert msg.code == 200
    assert msg.reason == "Ok"
    assert msg.headers == CIMultiDict([("TEST", "line continue")])
    assert msg.raw_headers == ((b"test", b"line continue"),)
    assert not upgraded
    assert not tail


@pytest.mark.dev_mode
def test_http_response_parser_strict_obs_line_folding(
    response: HttpResponseParser,
) -> None:
    """Under the dev_mode marker, a folded header line is rejected."""
    text = b"HTTP/1.1 200 Ok\r\ntest: line\r\n continue\r\n\r\n"

    with pytest.raises(http_exceptions.BadHttpMessage):
        response.feed_data(text)


@pytest.mark.parametrize("size", [40962, 8191])
def test_http_response_parser_bad_status_line_too_long(
    response: HttpResponseParser, size: int
) -> None:
    """An over-long status line raises LineTooLong."""
    reason = b"t" * (size - 2)
    # " +" accepts any indent width before the message text.
    match = "400, message:\n +Got more than 8190 bytes when reading"
    with pytest.raises(http_exceptions.LineTooLong, match=match):
        response.feed_data(b"HTTP/1.1 200 Ok" + reason + b"\r\n\r\n")


def test_http_response_parser_status_line_under_limit(
    response: HttpResponseParser,
) -> None:
    """A reason phrase of 8177 bytes is accepted."""
    reason = b"O" * 8177
    messages, upgraded, tail = response.feed_data(
        b"HTTP/1.1 200 " + reason + b"\r\n\r\n"
    )
    msg = messages[0][0]
    assert msg.version == (1, 1)
    assert msg.code == 200
    assert msg.reason == reason.decode()


def test_http_response_parser_bad_version(response: HttpResponseParser) -> None:
    with pytest.raises(http_exceptions.BadHttpMessage):
        response.feed_data(b"HT/11 200 Ok\r\n\r\n")


def test_http_response_parser_bad_version_number(response: HttpResponseParser) -> None:
    with pytest.raises(http_exceptions.BadHttpMessage):
        response.feed_data(b"HTTP/12.3 200 Ok\r\n\r\n")


def test_http_response_parser_no_reason(response: HttpResponseParser) -> None:
    """A status line with no reason phrase at all gives an empty reason."""
    msg = response.feed_data(b"HTTP/1.1 200\r\n\r\n")[0][0][0]

    assert msg.version == (1, 1)
    assert msg.code == 200
    assert msg.reason == ""


def test_http_response_parser_lenient_headers(response: HttpResponseParser) -> None:
    """A \\x01 control byte in a response header value is accepted."""
    messages, upgrade, tail = response.feed_data(
        b"HTTP/1.1 200 test\r\nFoo: abc\x01def\r\n\r\n"
    )
    msg = messages[0][0]

    assert msg.headers["Foo"] == "abc\x01def"


@pytest.mark.dev_mode
def test_http_response_parser_strict_headers(response: HttpResponseParser) -> None:
    """Under dev_mode, \\x01 in a header value is rejected (xfail on Py parser)."""
    if isinstance(response, HttpResponseParserPy):
        pytest.xfail("Py parser is lenient. May update py-parser later.")
    with pytest.raises(http_exceptions.BadHttpMessage):  # type: ignore[unreachable]
        response.feed_data(b"HTTP/1.1 200 test\r\nFoo: abc\x01def\r\n\r\n")


def test_http_response_parser_null_byte_in_header_value(
    response: HttpResponseParser,
) -> None:
    with pytest.raises(http_exceptions.InvalidHeader):
        response.feed_data(b"HTTP/1.1 200 OK\r\nFoo: abc\x00def\r\n\r\n")


def test_http_response_parser_bad_crlf(response: HttpResponseParser) -> None:
    """Still a lot of dodgy servers sending bad requests like this."""
    messages, upgrade, tail = response.feed_data(
        b"HTTP/1.0 200 OK\nFoo: abc\nBar: def\n\nBODY\n"
    )
    msg = messages[0][0]

    assert msg.headers["Foo"] == "abc"
    assert msg.headers["Bar"] == "def"


async def test_http_response_parser_bad_chunked_lax(
    response: HttpResponseParser,
) -> None:
    """A chunk size followed by a space is tolerated."""
    text = (
        b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5 \r\nabcde\r\n0\r\n\r\n"
    )
    messages, upgrade, tail = response.feed_data(text)

    assert await messages[0][1].read(5) == b"abcde"


@pytest.mark.dev_mode
async def test_http_response_parser_bad_chunked_strict_py(
    loop: asyncio.AbstractEventLoop, protocol: BaseProtocol
) -> None:
    """Under dev_mode, the Py parser rejects a chunk size followed by a space."""
    response = HttpResponseParserPy(
        protocol,
        loop,
        2**16,
        max_line_size=8190,
        max_field_size=8190,
    )
    text = (
        b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5 \r\nabcde\r\n0\r\n\r\n"
    )
    with pytest.raises(http_exceptions.TransferEncodingError, match="5"):
        response.feed_data(text)


@pytest.mark.dev_mode
@pytest.mark.skipif(
    "HttpRequestParserC" not in dir(aiohttp.http_parser),
    reason="C based HTTP parser not available",
)
async def test_http_response_parser_bad_chunked_strict_c(
    loop: asyncio.AbstractEventLoop, protocol: BaseProtocol
) -> None:
    """Under dev_mode, the C parser rejects a chunk size followed by a space."""
    response = HttpResponseParserC(
        protocol,
        loop,
        2**16,
        max_line_size=8190,
        max_field_size=8190,
    )
    text = (
        b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5 \r\nabcde\r\n0\r\n\r\n"
    )
    with pytest.raises(http_exceptions.BadHttpMessage):
        response.feed_data(text)


async def test_http_response_parser_notchunked(
    response: HttpResponseParser,
) -> None:
    """A non-chunked transfer coding leaves the body undecoded until EOF."""
    text = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: notchunked\r\n\r\n1\r\nT\r\n3\r\nest\r\n0\r\n\r\n"
    messages, upgrade, tail = response.feed_data(text)
    response.feed_eof()

    # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.2
    assert await messages[0][1].read() == b"1\r\nT\r\n3\r\nest\r\n0\r\n\r\n"


async def test_http_response_parser_last_chunked(
    response: HttpResponseParser,
) -> None:
    """`chunked` as the final transfer coding is decoded in responses."""
    text = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: not, chunked\r\n\r\n1\r\nT\r\n3\r\nest\r\n0\r\n\r\n"
    messages, upgrade, tail = response.feed_data(text)

    # https://www.rfc-editor.org/rfc/rfc9112#section-6.3-2.4.2
    assert await messages[0][1].read() == b"Test"


def test_http_response_parser_bad(response: HttpResponseParser) -> None:
    with pytest.raises(http_exceptions.BadHttpMessage):
        response.feed_data(b"HTT/1\r\n\r\n")


def test_http_response_parser_code_under_100(response: HttpResponseParser) -> None:
    with pytest.raises(http_exceptions.BadStatusLine):
        response.feed_data(b"HTTP/1.1 99 test\r\n\r\n")


def test_http_response_parser_code_above_999(response: HttpResponseParser) -> None:
    with pytest.raises(http_exceptions.BadStatusLine):
        response.feed_data(b"HTTP/1.1 9999 test\r\n\r\n")


def test_http_response_parser_code_not_int(response: HttpResponseParser) -> None:
    with pytest.raises(http_exceptions.BadStatusLine):
        response.feed_data(b"HTTP/1.1 ttt test\r\n\r\n")


@pytest.mark.parametrize("nonascii_digit", _num.keys(), ids=_num.values())
def test_http_response_parser_code_not_ascii(
    response: HttpResponseParser, nonascii_digit: bytes
) -> None:
    """A non-ASCII digit in the status code is rejected."""
    with pytest.raises(http_exceptions.BadStatusLine):
        response.feed_data(b"HTTP/1.1 20" + nonascii_digit + b" test\r\n\r\n")


def test_http_request_chunked_payload(parser: HttpRequestParser) -> None:
    """Chunks fed after the headers land in the payload with their split offsets."""
    text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n"
    msg, payload = parser.feed_data(text)[0][0]

    assert msg.chunked
    assert not payload.is_eof()
    assert isinstance(payload, streams.StreamReader)

    parser.feed_data(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n")

    assert b"dataline" == b"".join(payload._buffer)
    assert payload._http_chunk_splits is not None
    assert [4, 8] == list(payload._http_chunk_splits)
    assert payload.is_eof()


def test_http_request_chunked_payload_and_next_message(
    parser: HttpRequestParser,
) -> None:
    """A request following a finished chunked body is parsed in the same feed."""
    text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n"
    msg, payload = parser.feed_data(text)[0][0]

    messages, upgraded, tail = parser.feed_data(
        b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"
        b"POST /test2 HTTP/1.1\r\n"
        b"transfer-encoding: chunked\r\n\r\n"
    )

    assert b"dataline" == b"".join(payload._buffer)
    assert payload._http_chunk_splits is not None
    assert [4, 8] == list(payload._http_chunk_splits)
    assert payload.is_eof()

    assert len(messages) == 1
    msg2, payload2 = messages[0]

    assert msg2.method == "POST"
    assert msg2.chunked
    assert not payload2.is_eof()


def test_http_request_chunked_payload_chunks(parser: HttpRequestParser) -> None:
    """Chunk framing split at arbitrary byte boundaries is still decoded."""
    text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n"
    msg, payload = parser.feed_data(text)[0][0]

    parser.feed_data(b"4\r\ndata\r")
    parser.feed_data(b"\n4")
    parser.feed_data(b"\r")
    parser.feed_data(b"\n")
    parser.feed_data(b"li")
    parser.feed_data(b"ne\r\n0\r\n")
    parser.feed_data(b"test: test\r\n")

    assert b"dataline" == b"".join(payload._buffer)
    assert payload._http_chunk_splits is not None
    assert [4, 8] == list(payload._http_chunk_splits)
    # The trailer section has not been terminated yet.
    assert not payload.is_eof()

    parser.feed_data(b"\r\n")
    assert b"dataline" == b"".join(payload._buffer)
    assert [4, 8] == list(payload._http_chunk_splits)
    assert payload.is_eof()


def test_parse_chunked_payload_chunk_extension(parser: HttpRequestParser) -> None:
    """A chunk extension (`4;test`) does not affect the decoded data."""
    text = b"GET /test HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n"
    msg, payload = parser.feed_data(text)[0][0]

    parser.feed_data(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest: test\r\n\r\n")

    assert b"dataline" == b"".join(payload._buffer)
    assert payload._http_chunk_splits is not None
    assert [4, 8] == list(payload._http_chunk_splits)
    assert payload.is_eof()


async def test_request_chunked_with_trailer(parser: HttpRequestParser) -> None:
    """Trailers after the last chunk are consumed and leave no tail."""
    text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n0\r\ntest: trailer\r\nsecond: test trailer\r\n\r\n"
    messages, upgraded, tail = parser.feed_data(text)
    assert not tail
    msg, payload = messages[0]
    assert await payload.read() == b"test"

    # TODO: Add assertion of trailers when API added.
async def test_request_chunked_reject_bad_trailer(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nbad\ntrailer\r\n\r\n" with pytest.raises(http_exceptions.BadHttpMessage, match=r"b'bad\\ntrailer'"): parser.feed_data(text) def test_parse_no_length_or_te_on_post( loop: asyncio.AbstractEventLoop, protocol: BaseProtocol, request_cls: type[HttpRequestParser], ) -> None: parser = request_cls(protocol, loop, limit=2**16) text = b"POST /test HTTP/1.1\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] assert payload.is_eof() def test_parse_payload_response_without_body( loop: asyncio.AbstractEventLoop, protocol: BaseProtocol, response_cls: type[HttpResponseParser], ) -> None: parser = response_cls(protocol, loop, 2**16, response_with_body=False) text = b"HTTP/1.1 200 Ok\r\ncontent-length: 10\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] assert payload.is_eof() def test_parse_length_payload(response: HttpResponseParser) -> None: text = b"HTTP/1.1 200 Ok\r\ncontent-length: 4\r\n\r\n" msg, payload = response.feed_data(text)[0][0] assert not payload.is_eof() response.feed_data(b"da") response.feed_data(b"t") response.feed_data(b"aHT") assert payload.is_eof() assert b"data" == b"".join(d for d in payload._buffer) def test_parse_no_length_payload(parser: HttpRequestParser) -> None: text = b"PUT / HTTP/1.1\r\n\r\n" msg, payload = parser.feed_data(text)[0][0] assert payload.is_eof() def test_parse_content_length_payload_multiple(response: HttpResponseParser) -> None: text = b"HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\nfirst" msg, payload = response.feed_data(text)[0][0] assert msg.version == HttpVersion(major=1, minor=1) assert msg.code == 200 assert msg.reason == "OK" assert msg.headers == CIMultiDict( [ ("Content-Length", "5"), ] ) assert msg.raw_headers == ((b"content-length", b"5"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade assert not msg.chunked assert payload.is_eof() 
assert b"first" == b"".join(d for d in payload._buffer) text = b"HTTP/1.1 200 OK\r\ncontent-length: 6\r\n\r\nsecond" msg, payload = response.feed_data(text)[0][0] assert msg.version == HttpVersion(major=1, minor=1) assert msg.code == 200 assert msg.reason == "OK" assert msg.headers == CIMultiDict( [ ("Content-Length", "6"), ] ) assert msg.raw_headers == ((b"content-length", b"6"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade assert not msg.chunked assert payload.is_eof() assert b"second" == b"".join(d for d in payload._buffer) def test_parse_content_length_than_chunked_payload( response: HttpResponseParser, ) -> None: text = b"HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\nfirst" msg, payload = response.feed_data(text)[0][0] assert msg.version == HttpVersion(major=1, minor=1) assert msg.code == 200 assert msg.reason == "OK" assert msg.headers == CIMultiDict( [ ("Content-Length", "5"), ] ) assert msg.raw_headers == ((b"content-length", b"5"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade assert not msg.chunked assert payload.is_eof() assert b"first" == b"".join(d for d in payload._buffer) text = ( b"HTTP/1.1 200 OK\r\n" b"transfer-encoding: chunked\r\n\r\n" b"6\r\nsecond\r\n0\r\n\r\n" ) msg, payload = response.feed_data(text)[0][0] assert msg.version == HttpVersion(major=1, minor=1) assert msg.code == 200 assert msg.reason == "OK" assert msg.headers == CIMultiDict( [ ("Transfer-Encoding", "chunked"), ] ) assert msg.raw_headers == ((b"transfer-encoding", b"chunked"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade assert msg.chunked assert payload.is_eof() assert b"second" == b"".join(d for d in payload._buffer) @pytest.mark.parametrize("code", (204, 304, 101, 102)) def test_parse_chunked_payload_empty_body_than_another_chunked( response: HttpResponseParser, code: int ) -> None: head = f"HTTP/1.1 {code} OK\r\n".encode() text = head + b"transfer-encoding: 
chunked\r\n\r\n" msg, payload = response.feed_data(text)[0][0] assert msg.version == HttpVersion(major=1, minor=1) assert msg.code == code assert msg.reason == "OK" assert msg.headers == CIMultiDict( [ ("Transfer-Encoding", "chunked"), ] ) assert msg.raw_headers == ((b"transfer-encoding", b"chunked"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade assert msg.chunked assert payload.is_eof() text = ( b"HTTP/1.1 200 OK\r\n" b"transfer-encoding: chunked\r\n\r\n" b"6\r\nsecond\r\n0\r\n\r\n" ) msg, payload = response.feed_data(text)[0][0] assert msg.version == HttpVersion(major=1, minor=1) assert msg.code == 200 assert msg.reason == "OK" assert msg.headers == CIMultiDict( [ ("Transfer-Encoding", "chunked"), ] ) assert msg.raw_headers == ((b"transfer-encoding", b"chunked"),) assert not msg.should_close assert msg.compression is None assert not msg.upgrade assert msg.chunked assert payload.is_eof() assert b"second" == b"".join(d for d in payload._buffer) async def test_parse_chunked_payload_split_chunks(response: HttpResponseParser) -> None: network_chunks = ( b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n", b"5\r\nfi", b"rst", # This simulates a bug in lax mode caused when the \r\n separator, before the # next HTTP chunk, appears at the start of the next network chunk. 
b"\r\n", b"6", b"\r", b"\n", b"second\r", b"\n0\r\n\r\n", ) reader = response.feed_data(network_chunks[0])[0][0][1] for c in network_chunks[1:]: response.feed_data(c) assert response.feed_eof() is None assert reader.is_eof() assert await reader.read() == b"firstsecond" async def test_parse_chunked_payload_with_lf_in_extensions( parser: HttpRequestParser, ) -> None: """Test chunked payload that has a LF in the chunk extensions.""" payload = ( b"GET / HTTP/1.1\r\nHost: localhost:5001\r\n" b"Transfer-Encoding: chunked\r\n\r\n2;\nxx\r\n4c\r\n0\r\n\r\n" b"GET /admin HTTP/1.1\r\nHost: localhost:5001\r\n" b"Transfer-Encoding: chunked\r\n\r\n0\r\n\r\n" ) with pytest.raises(http_exceptions.BadHttpMessage, match="\\\\nxx"): parser.feed_data(payload) def test_partial_url(parser: HttpRequestParser) -> None: messages, upgrade, tail = parser.feed_data(b"GET /te") assert len(messages) == 0 messages, upgrade, tail = parser.feed_data(b"st HTTP/1.1\r\n\r\n") assert len(messages) == 1 msg, payload = messages[0] assert msg.method == "GET" assert msg.path == "/test" assert msg.version == (1, 1) assert payload.is_eof() @pytest.mark.parametrize( ("uri", "path", "query", "fragment"), [ ("/path%23frag", "/path#frag", {}, ""), ("/path%2523frag", "/path%23frag", {}, ""), ("/path?key=value%23frag", "/path", {"key": "value#frag"}, ""), ("/path?key=value%2523frag", "/path", {"key": "value%23frag"}, ""), ("/path#frag%20", "/path", {}, "frag "), ("/path#frag%2520", "/path", {}, "frag%20"), ], ) def test_parse_uri_percent_encoded( parser: HttpRequestParser, uri: str, path: str, query: dict[str, str], fragment: str ) -> None: text = (f"GET {uri} HTTP/1.1\r\n\r\n").encode() messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.path == uri assert msg.url == URL(uri) assert msg.url.path == path assert msg.url.query == query assert msg.url.fragment == fragment def test_parse_uri_utf8(parser: HttpRequestParser) -> None: if not isinstance(parser, HttpRequestParserPy): 
pytest.xfail("Not valid HTTP. Maybe update py-parser to reject later.") text = ("GET /путь?ключ=знач#фраг HTTP/1.1\r\n\r\n").encode() messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.path == "/путь?ключ=знач#фраг" assert msg.url.path == "/путь" assert msg.url.query == {"ключ": "знач"} assert msg.url.fragment == "фраг" def test_parse_uri_utf8_percent_encoded(parser: HttpRequestParser) -> None: text = ( "GET %s HTTP/1.1\r\n\r\n" % quote("/путь?ключ=знач#фраг", safe="/?=#") ).encode() messages, upgrade, tail = parser.feed_data(text) msg = messages[0][0] assert msg.path == quote("/путь?ключ=знач#фраг", safe="/?=#") assert msg.url == URL("/путь?ключ=знач#фраг") assert msg.url.path == "/путь" assert msg.url.query == {"ключ": "знач"} assert msg.url.fragment == "фраг" @pytest.mark.skipif( "HttpRequestParserC" not in dir(aiohttp.http_parser), reason="C based HTTP parser not available", ) def test_parse_bad_method_for_c_parser_raises( loop: asyncio.AbstractEventLoop, protocol: BaseProtocol ) -> None: payload = b"GET1 /test HTTP/1.1\r\n\r\n" parser = HttpRequestParserC( protocol, loop, 2**16, max_line_size=8190, max_headers=128, max_field_size=8190, ) with pytest.raises(aiohttp.http_exceptions.BadStatusLine): messages, upgrade, tail = parser.feed_data(payload) class TestParsePayload: async def test_parse_eof_payload(self, protocol: BaseProtocol) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, headers_parser=HeadersParser()) p.feed_data(b"data") p.feed_eof() assert out.is_eof() assert [bytearray(b"data")] == list(out._buffer) async def test_parse_length_payload_eof(self, protocol: BaseProtocol) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, length=4, headers_parser=HeadersParser()) p.feed_data(b"da") with pytest.raises(http_exceptions.ContentLengthError): p.feed_eof() async def test_parse_chunked_payload_size_error( 
self, protocol: BaseProtocol ) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) with pytest.raises(http_exceptions.TransferEncodingError): p.feed_data(b"blah\r\n") assert isinstance(out.exception(), http_exceptions.TransferEncodingError) async def test_parse_chunked_payload_size_data_mismatch( self, protocol: BaseProtocol ) -> None: """Chunk-size does not match actual data: should raise, not hang. Regression test for #10596. """ out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) # Declared chunk-size is 4 but actual data is "Hello" (5 bytes). # After consuming 4 bytes, remaining starts with "o" not "\r\n". with pytest.raises(http_exceptions.TransferEncodingError): p.feed_data(b"4\r\nHello\r\n0\r\n\r\n") assert isinstance(out.exception(), http_exceptions.TransferEncodingError) async def test_parse_chunked_payload_size_data_mismatch_too_short( self, protocol: BaseProtocol ) -> None: """Chunk-size larger than data: declared 6 but only 5 bytes before CRLF. Regression test for #10596. """ out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) # Declared chunk-size is 6 but actual data before CRLF is "Hello" (5 bytes). # Parser reads 6 bytes: "Hello\r", then expects \r\n but sees "\n0\r\n..." 
with pytest.raises(http_exceptions.TransferEncodingError): p.feed_data(b"6\r\nHello\r\n0\r\n\r\n") assert isinstance(out.exception(), http_exceptions.TransferEncodingError) async def test_parse_chunked_payload_split_end( self, protocol: BaseProtocol ) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) p.feed_data(b"4\r\nasdf\r\n0\r\n") p.feed_data(b"\r\n") assert out.is_eof() assert b"asdf" == b"".join(out._buffer) async def test_parse_chunked_payload_split_end2( self, protocol: BaseProtocol ) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) p.feed_data(b"4\r\nasdf\r\n0\r\n\r") p.feed_data(b"\n") assert out.is_eof() assert b"asdf" == b"".join(out._buffer) async def test_parse_chunked_payload_split_end_trailers( self, protocol: BaseProtocol ) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) p.feed_data(b"4\r\nasdf\r\n0\r\n") p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n") p.feed_data(b"\r\n") assert out.is_eof() assert b"asdf" == b"".join(out._buffer) async def test_parse_chunked_payload_split_end_trailers2( self, protocol: BaseProtocol ) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) p.feed_data(b"4\r\nasdf\r\n0\r\n") p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n\r") p.feed_data(b"\n") assert out.is_eof() assert b"asdf" == b"".join(out._buffer) async def test_parse_chunked_payload_split_end_trailers3( self, protocol: BaseProtocol ) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) 
p.feed_data(b"4\r\nasdf\r\n0\r\nContent-MD5: ") p.feed_data(b"912ec803b2ce49e4a541068d495ab570\r\n\r\n") assert out.is_eof() assert b"asdf" == b"".join(out._buffer) async def test_parse_chunked_payload_split_end_trailers4( self, protocol: BaseProtocol ) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser()) p.feed_data(b"4\r\nasdf\r\n0\r\nC") p.feed_data(b"ontent-MD5: 912ec803b2ce49e4a541068d495ab570\r\n\r\n") assert out.is_eof() assert b"asdf" == b"".join(out._buffer) async def test_http_payload_parser_length(self, protocol: BaseProtocol) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, length=2, headers_parser=HeadersParser()) eof, tail = p.feed_data(b"1245") assert eof assert b"12" == out._buffer[0] assert b"45" == tail async def test_http_payload_parser_deflate(self, protocol: BaseProtocol) -> None: # c=compressobj(wbits=15); b''.join([c.compress(b'data'), c.flush()]) COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b" length = len(COMPRESSED) out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=length, compression="deflate", headers_parser=HeadersParser() ) p.feed_data(COMPRESSED) assert b"data" == out._buffer[0] assert out.is_eof() async def test_http_payload_parser_deflate_no_hdrs( self, protocol: BaseProtocol ) -> None: """Tests incorrectly formed data (no zlib headers).""" # c=compressobj(wbits=-15); b''.join([c.compress(b'data'), c.flush()]) COMPRESSED = b"KI,I\x04\x00" length = len(COMPRESSED) out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=length, compression="deflate", headers_parser=HeadersParser() ) p.feed_data(COMPRESSED) assert b"data" == out._buffer[0] assert out.is_eof() async def test_http_payload_parser_deflate_light( self, protocol: BaseProtocol ) -> None: # 
c=compressobj(wbits=9); b''.join([c.compress(b'data'), c.flush()]) COMPRESSED = b"\x18\x95KI,I\x04\x00\x04\x00\x01\x9b" length = len(COMPRESSED) out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=length, compression="deflate", headers_parser=HeadersParser() ) p.feed_data(COMPRESSED) assert b"data" == out._buffer[0] assert out.is_eof() async def test_http_payload_parser_deflate_split( self, protocol: BaseProtocol ) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, compression="deflate", headers_parser=HeadersParser() ) # Feeding one correct byte should be enough to choose exact # deflate decompressor p.feed_data(b"x") p.feed_data(b"\x9cKI,I\x04\x00\x04\x00\x01\x9b") p.feed_eof() assert b"data" == out._buffer[0] async def test_http_payload_parser_deflate_split_err( self, protocol: BaseProtocol ) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, compression="deflate", headers_parser=HeadersParser() ) # Feeding one wrong byte should be enough to choose exact # deflate decompressor p.feed_data(b"K") p.feed_data(b"I,I\x04\x00") p.feed_eof() assert b"data" == out._buffer[0] async def test_http_payload_parser_length_zero( self, protocol: BaseProtocol ) -> None: out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser(out, length=0, headers_parser=HeadersParser()) assert p.done assert out.is_eof() @pytest.mark.skipif(brotli is None, reason="brotli is not installed") async def test_http_payload_brotli(self, protocol: BaseProtocol) -> None: compressed = brotli.compress(b"brotli data") out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=len(compressed), compression="br", headers_parser=HeadersParser(), ) p.feed_data(compressed) assert b"brotli data" == out._buffer[0] assert out.is_eof() 
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed") async def test_http_payload_zstandard(self, protocol: BaseProtocol) -> None: compressed = zstandard.compress(b"zstd data") out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) p = HttpPayloadParser( out, length=len(compressed), compression="zstd", headers_parser=HeadersParser(), ) p.feed_data(compressed) assert b"zstd data" == out._buffer[0] assert out.is_eof() class TestDeflateBuffer: async def test_feed_data(self, protocol: BaseProtocol) -> None: buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() dbuf.decompressor.decompress_sync.return_value = b"line" # First byte should be b'x' in order code not to change the decoder. dbuf.feed_data(b"xxxx") assert [b"line"] == list(buf._buffer) async def test_feed_data_err(self, protocol: BaseProtocol) -> None: buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) dbuf = DeflateBuffer(buf, "deflate") exc = ValueError() dbuf.decompressor = mock.Mock() dbuf.decompressor.decompress_sync.side_effect = exc with pytest.raises(http_exceptions.ContentEncodingError): # Should be more than 4 bytes to trigger deflate FSM error. # Should start with b'x', otherwise code switch mocked decoder. 
dbuf.feed_data(b"xsomedata") async def test_feed_eof(self, protocol: BaseProtocol) -> None: buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() dbuf.decompressor.flush.return_value = b"line" dbuf.feed_eof() assert [b"line"] == list(buf._buffer) assert buf._eof async def test_feed_eof_err_deflate(self, protocol: BaseProtocol) -> None: buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() dbuf.decompressor.flush.return_value = b"line" dbuf.decompressor.eof = False with pytest.raises(http_exceptions.ContentEncodingError): dbuf.feed_eof() async def test_feed_eof_no_err_gzip(self, protocol: BaseProtocol) -> None: buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) dbuf = DeflateBuffer(buf, "gzip") dbuf.decompressor = mock.Mock() dbuf.decompressor.flush.return_value = b"line" dbuf.decompressor.eof = False dbuf.feed_eof() assert [b"line"] == list(buf._buffer) async def test_feed_eof_no_err_brotli(self, protocol: BaseProtocol) -> None: buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) dbuf = DeflateBuffer(buf, "br") dbuf.decompressor = mock.Mock() dbuf.decompressor.flush.return_value = b"line" dbuf.decompressor.eof = False dbuf.feed_eof() assert [b"line"] == list(buf._buffer) @pytest.mark.skipif(zstandard is None, reason="zstandard is not installed") async def test_feed_eof_no_err_zstandard(self, protocol: BaseProtocol) -> None: buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) dbuf = DeflateBuffer(buf, "zstd") dbuf.decompressor = mock.Mock() dbuf.decompressor.flush.return_value = b"line" dbuf.decompressor.eof = False dbuf.feed_eof() assert [b"line"] == list(buf._buffer) async def test_empty_body(self, protocol: BaseProtocol) -> None: buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) 
dbuf = DeflateBuffer(buf, "deflate") dbuf.feed_eof() assert buf.at_eof() @pytest.mark.parametrize( "chunk_size", [1024, 2**14, 2**16], # 1KB, 16KB, 64KB ids=["1KB", "16KB", "64KB"], ) async def test_streaming_decompress_large_payload( self, protocol: BaseProtocol, chunk_size: int ) -> None: """Test that large payloads decompress correctly when streamed in chunks. This simulates real HTTP streaming where compressed data arrives in small network chunks. Each chunk's decompressed output should be within the max_decompress_size limit, allowing full recovery of the original data. """ # Create a large payload (3MiB) that compresses well original = b"A" * (3 * 2**20) compressed = zlib.compress(original) buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) dbuf = DeflateBuffer(buf, "deflate") # Feed compressed data in chunks (simulating network streaming) for i in range(0, len(compressed), chunk_size): # pragma: no branch chunk = compressed[i : i + chunk_size] dbuf.feed_data(chunk) dbuf.feed_eof() # Read all decompressed data result = b"".join(buf._buffer) assert len(result) == len(original) assert result == original ================================================ FILE: tests/test_http_writer.py ================================================ # Tests for aiohttp/http_writer.py import array import asyncio import zlib from collections.abc import Generator, Iterable from typing import Any from unittest import mock import pytest from multidict import CIMultiDict from aiohttp import ClientConnectionResetError, hdrs, http from aiohttp.base_protocol import BaseProtocol from aiohttp.compression_utils import ZLibBackend from aiohttp.http_writer import _serialize_headers @pytest.fixture def enable_writelines() -> Generator[None, None, None]: with mock.patch("aiohttp.http_writer.SKIP_WRITELINES", False): yield @pytest.fixture def disable_writelines() -> Generator[None, None, None]: with mock.patch("aiohttp.http_writer.SKIP_WRITELINES", True): yield 
@pytest.fixture def force_writelines_small_payloads() -> Generator[None, None, None]: with mock.patch("aiohttp.http_writer.MIN_PAYLOAD_FOR_WRITELINES", 1): yield @pytest.fixture def buf() -> bytearray: return bytearray() @pytest.fixture def transport(buf: bytearray) -> Any: transport = mock.create_autospec(asyncio.Transport, spec_set=True, instance=True) def write(chunk: bytes) -> None: buf.extend(chunk) def writelines(chunks: Iterable[bytes]) -> None: for chunk in chunks: buf.extend(chunk) transport.write.side_effect = write transport.writelines.side_effect = writelines transport.is_closing.return_value = False return transport @pytest.fixture def protocol(loop: asyncio.AbstractEventLoop, transport: asyncio.Transport) -> Any: return mock.create_autospec( BaseProtocol, spec_set=True, instance=True, transport=transport ) def decompress(data: bytes) -> bytes: d = ZLibBackend.decompressobj() return d.decompress(data) def decode_chunked(chunked: bytes | bytearray) -> bytes: i = 0 out = b"" while i < len(chunked): j = chunked.find(b"\r\n", i) assert j != -1, "Malformed chunk" size = int(chunked[i:j], 16) if size == 0: break i = j + 2 out += chunked[i : i + size] i += size + 2 # skip \r\n after the chunk return out def test_payloadwriter_properties( transport: asyncio.Transport, protocol: BaseProtocol, loop: asyncio.AbstractEventLoop, ) -> None: writer = http.StreamWriter(protocol, loop) assert writer.protocol == protocol assert writer.transport == transport async def test_write_headers_buffered_small_payload( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) headers = CIMultiDict({"Content-Length": "11", "Host": "example.com"}) # Write headers - should be buffered await msg.write_headers("GET / HTTP/1.1", headers) assert len(buf) == 0 # Headers not sent yet # Write small body - should coalesce with headers await msg.write(b"Hello World", drain=False) # Verify content 
assert b"GET / HTTP/1.1\r\n" in buf assert b"Host: example.com\r\n" in buf assert b"Content-Length: 11\r\n" in buf assert b"\r\n\r\nHello World" in buf async def test_write_headers_chunked_coalescing( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_chunking() headers = CIMultiDict({"Transfer-Encoding": "chunked", "Host": "example.com"}) # Write headers - should be buffered await msg.write_headers("POST /upload HTTP/1.1", headers) assert len(buf) == 0 # Headers not sent yet # Write first chunk - should coalesce with headers await msg.write(b"First chunk", drain=False) # Verify content assert b"POST /upload HTTP/1.1\r\n" in buf assert b"Transfer-Encoding: chunked\r\n" in buf # "b" is hex for 11 (length of "First chunk") assert b"\r\n\r\nb\r\nFirst chunk\r\n" in buf async def test_write_eof_with_buffered_headers( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) headers = CIMultiDict({"Content-Length": "9", "Host": "example.com"}) # Write headers - should be buffered await msg.write_headers("POST /data HTTP/1.1", headers) assert len(buf) == 0 # Call write_eof with body - should coalesce await msg.write_eof(b"Last data") # Verify content assert b"POST /data HTTP/1.1\r\n" in buf assert b"\r\n\r\nLast data" in buf async def test_set_eof_sends_buffered_headers( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) headers = CIMultiDict({"Host": "example.com"}) # Write headers - should be buffered await msg.write_headers("GET /empty HTTP/1.1", headers) assert len(buf) == 0 # Call set_eof without body - headers should be sent msg.set_eof() # Headers should be sent assert len(buf) > 0 assert b"GET /empty HTTP/1.1\r\n" in buf async def test_write_payload_eof( 
transport: asyncio.Transport, protocol: BaseProtocol, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) await msg.write(b"data1") await msg.write(b"data2") await msg.write_eof() content = b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined] assert b"data1data2" == content.split(b"\r\n\r\n", 1)[-1] async def test_write_payload_chunked( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_chunking() await msg.write(b"data") await msg.write_eof() assert b"4\r\ndata\r\n0\r\n\r\n" == buf async def test_write_payload_chunked_multiple( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_chunking() await msg.write(b"data1") await msg.write(b"data2") await msg.write_eof() assert b"5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n" == buf async def test_write_payload_length( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.length = 2 await msg.write(b"d") await msg.write(b"ata") await msg.write_eof() content = b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined] assert b"da" == content.split(b"\r\n\r\n", 1)[-1] @pytest.mark.usefixtures("disable_writelines") @pytest.mark.internal # Used for performance benchmarking async def test_write_large_payload_deflate_compression_data_in_eof( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") await msg.write(b"data" * 4096) assert transport.write.called # type: ignore[attr-defined] chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] transport.write.reset_mock() # type: 
ignore[attr-defined] # This payload compresses to 20447 bytes payload = b"".join( [bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)] ) await msg.write_eof(payload) chunks.extend([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert zlib.decompress(content) == (b"data" * 4096) + payload @pytest.mark.usefixtures("disable_writelines") @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_write_large_payload_deflate_compression_data_in_eof_all_zlib( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") await msg.write(b"data" * 4096) # Behavior depends on zlib backend, isal compress() returns b'' initially # and the entire compressed bytes at flush() for this data backend_to_write_called = { "isal.isal_zlib": False, "zlib": True, "zlib_ng.zlib_ng": True, } assert transport.write.called == backend_to_write_called[ZLibBackend.name] # type: ignore[attr-defined] chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] transport.write.reset_mock() # type: ignore[attr-defined] # This payload compresses to 20447 bytes payload = b"".join( [bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)] ) await msg.write_eof(payload) chunks.extend([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert ZLibBackend.decompress(content) == (b"data" * 4096) + payload @pytest.mark.usefixtures("enable_writelines") @pytest.mark.internal # Used for performance benchmarking async def test_write_large_payload_deflate_compression_data_in_eof_writelines( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") 
await msg.write(b"data" * 4096) assert transport.write.called # type: ignore[attr-defined] chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] transport.write.reset_mock() # type: ignore[attr-defined] assert not transport.writelines.called # type: ignore[attr-defined] # This payload compresses to 20447 bytes payload = b"".join( [bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)] ) await msg.write_eof(payload) assert not transport.write.called # type: ignore[attr-defined] assert transport.writelines.called # type: ignore[attr-defined] chunks.extend(transport.writelines.mock_calls[0][1][0]) # type: ignore[attr-defined] content = b"".join(chunks) assert zlib.decompress(content) == (b"data" * 4096) + payload @pytest.mark.usefixtures("enable_writelines") @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_write_large_payload_deflate_compression_data_in_eof_writelines_all_zlib( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") await msg.write(b"data" * 4096) # Behavior depends on zlib backend, isal compress() returns b'' initially # and the entire compressed bytes at flush() for this data backend_to_write_called = { "isal.isal_zlib": False, "zlib": True, "zlib_ng.zlib_ng": True, } assert transport.write.called == backend_to_write_called[ZLibBackend.name] # type: ignore[attr-defined] chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] transport.write.reset_mock() # type: ignore[attr-defined] assert not transport.writelines.called # type: ignore[attr-defined] # This payload compresses to 20447 bytes payload = b"".join( [bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)] ) await msg.write_eof(payload) assert transport.writelines.called != transport.write.called # type: ignore[attr-defined] if 
transport.writelines.called: # type: ignore[attr-defined] chunks.extend(transport.writelines.mock_calls[0][1][0]) # type: ignore[attr-defined] else: # transport.write.called: # type: ignore[attr-defined] chunks.extend([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined] content = b"".join(chunks) assert ZLibBackend.decompress(content) == (b"data" * 4096) + payload async def test_write_payload_chunked_filter( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_chunking() await msg.write(b"da") await msg.write(b"ta") await msg.write_eof() content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]) # type: ignore[attr-defined] content += b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined] assert content.endswith(b"2\r\nda\r\n2\r\nta\r\n0\r\n\r\n") async def test_write_payload_chunked_filter_multiple_chunks( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_chunking() await msg.write(b"da") await msg.write(b"ta") await msg.write(b"1d") await msg.write(b"at") await msg.write(b"a2") await msg.write_eof() content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]) # type: ignore[attr-defined] content += b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined] assert content.endswith( b"2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n2\r\na2\r\n0\r\n\r\n" ) @pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_compression( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b" msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") await msg.write(b"data") await msg.write_eof() 
chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert COMPRESSED == content.split(b"\r\n\r\n", 1)[-1] @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_write_payload_deflate_compression_all_zlib( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") await msg.write(b"data") await msg.write_eof() chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert b"data" == decompress(content) @pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_compression_chunked( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: expected = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n" msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data") await msg.write_eof() chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert content == expected @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_write_payload_deflate_compression_chunked_all_zlib( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data") await msg.write_eof() chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert b"data" == decompress(decode_chunked(content)) @pytest.mark.usefixtures("enable_writelines") @pytest.mark.usefixtures("force_writelines_small_payloads") @pytest.mark.internal # Used for performance 
benchmarking async def test_write_payload_deflate_compression_chunked_writelines( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: expected = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n" msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data") await msg.write_eof() chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert content == expected @pytest.mark.usefixtures("enable_writelines") @pytest.mark.usefixtures("force_writelines_small_payloads") @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_write_payload_deflate_compression_chunked_writelines_all_zlib( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data") await msg.write_eof() chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert b"data" == decompress(decode_chunked(content)) @pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_and_chunked( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"da") await msg.write(b"ta") await msg.write_eof() thing = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n" assert thing == buf @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_write_payload_deflate_and_chunked_all_zlib( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, 
loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"da") await msg.write(b"ta") await msg.write_eof() assert b"data" == decompress(decode_chunked(buf)) @pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_compression_chunked_data_in_eof( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: expected = b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n" msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data") await msg.write_eof(b"end") chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert content == expected @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_write_payload_deflate_compression_chunked_data_in_eof_all_zlib( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data") await msg.write_eof(b"end") chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert b"dataend" == decompress(decode_chunked(content)) @pytest.mark.usefixtures("enable_writelines") @pytest.mark.usefixtures("force_writelines_small_payloads") @pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: expected = b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n" msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data") await msg.write_eof(b"end") chunks = [b"".join(c[1][0]) for c in 
list(transport.writelines.mock_calls)] # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert content == expected @pytest.mark.usefixtures("enable_writelines") @pytest.mark.usefixtures("force_writelines_small_payloads") @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines_all_zlib( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data") await msg.write_eof(b"end") chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined] assert all(chunks) content = b"".join(chunks) assert b"dataend" == decompress(decode_chunked(content)) @pytest.mark.internal # Used for performance benchmarking async def test_write_large_payload_deflate_compression_chunked_data_in_eof( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data" * 4096) # This payload compresses to 1111 bytes payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)]) await msg.write_eof(payload) compressed = [] chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] chunked_body = b"".join(chunks) split_body = chunked_body.split(b"\r\n") while split_body: if split_body.pop(0): compressed.append(split_body.pop(0)) content = b"".join(compressed) assert zlib.decompress(content) == (b"data" * 4096) + payload @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_write_large_payload_deflate_compression_chunked_data_in_eof_all_zlib( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) 
msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data" * 4096) # This payload compresses to 1111 bytes payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)]) await msg.write_eof(payload) compressed = [] chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined] chunked_body = b"".join(chunks) split_body = chunked_body.split(b"\r\n") while split_body: if split_body.pop(0): compressed.append(split_body.pop(0)) content = b"".join(compressed) assert ZLibBackend.decompress(content) == (b"data" * 4096) + payload @pytest.mark.usefixtures("enable_writelines") @pytest.mark.usefixtures("force_writelines_small_payloads") @pytest.mark.internal # Used for performance benchmarking async def test_write_large_payload_deflate_compression_chunked_data_in_eof_writelines( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data" * 4096) # This payload compresses to 1111 bytes payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)]) await msg.write_eof(payload) assert not transport.write.called # type: ignore[attr-defined] chunks = [] for write_lines_call in transport.writelines.mock_calls: # type: ignore[attr-defined] chunked_payload = list(write_lines_call[1][0])[1:] chunked_payload.pop() chunks.extend(chunked_payload) assert all(chunks) content = b"".join(chunks) assert zlib.decompress(content) == (b"data" * 4096) + payload @pytest.mark.usefixtures("enable_writelines") @pytest.mark.usefixtures("force_writelines_small_payloads") @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_write_large_payload_deflate_compression_chunked_data_in_eof_writelines_all_zlib( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) 
msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data" * 4096) # This payload compresses to 1111 bytes payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)]) await msg.write_eof(payload) assert not transport.write.called # type: ignore[attr-defined] chunks = [] for write_lines_call in transport.writelines.mock_calls: # type: ignore[attr-defined] chunked_payload = list(write_lines_call[1][0])[1:] chunked_payload.pop() chunks.extend(chunked_payload) assert all(chunks) content = b"".join(chunks) assert ZLibBackend.decompress(content) == (b"data" * 4096) + payload @pytest.mark.internal # Used for performance benchmarking async def test_write_payload_deflate_compression_chunked_connection_lost( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data") with ( pytest.raises( ClientConnectionResetError, match="Cannot write to closing transport" ), mock.patch.object(transport, "is_closing", return_value=True), ): await msg.write_eof(b"end") @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_write_payload_deflate_compression_chunked_connection_lost_all_zlib( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() await msg.write(b"data") with ( pytest.raises( ClientConnectionResetError, match="Cannot write to closing transport" ), mock.patch.object(transport, "is_closing", return_value=True), ): await msg.write_eof(b"end") async def test_write_payload_bytes_memoryview( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) mv = memoryview(b"abcd") await msg.write(mv) await msg.write_eof() thing = b"abcd" assert thing 
== buf async def test_write_payload_short_ints_memoryview( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_chunking() payload = memoryview(array.array("H", [65, 66, 67])) await msg.write(payload) await msg.write_eof() endians = ( (b"6\r\n\x00A\x00B\x00C\r\n0\r\n\r\n"), (b"6\r\nA\x00B\x00C\x00\r\n0\r\n\r\n"), ) assert buf in endians async def test_write_payload_2d_shape_memoryview( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_chunking() mv = memoryview(b"ABCDEF") payload = mv.cast("c", [3, 2]) await msg.write(payload) await msg.write_eof() thing = b"6\r\nABCDEF\r\n0\r\n\r\n" assert thing == buf async def test_write_payload_slicing_long_memoryview( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.length = 4 mv = memoryview(b"ABCDEF") payload = mv.cast("c", [3, 2]) await msg.write(payload) await msg.write_eof() thing = b"ABCD" assert thing == buf async def test_write_drain( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) with mock.patch.object(msg, "drain", autospec=True, spec_set=True) as m: await msg.write(b"1" * (64 * 1024 * 2), drain=False) assert not m.called await msg.write(b"1", drain=True) assert m.called assert msg.buffer_size == 0 # type: ignore[unreachable] async def test_write_calls_callback( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: async def on_chunk_sent(chunk: bytes) -> None: """Mock signature""" on_chunk_sent_mock = mock.create_autospec(on_chunk_sent, spec_set=True) msg = http.StreamWriter(protocol, loop, on_chunk_sent=on_chunk_sent_mock) chunk = b"1" await 
msg.write(chunk) assert on_chunk_sent_mock.called assert on_chunk_sent_mock.call_args == mock.call(chunk) async def test_write_eof_calls_callback( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: async def on_chunk_sent(chunk: bytes) -> None: """Mock signature""" on_chunk_sent_mock = mock.create_autospec(on_chunk_sent, spec_set=True) msg = http.StreamWriter(protocol, loop, on_chunk_sent=on_chunk_sent_mock) chunk = b"1" await msg.write_eof(chunk=chunk) assert on_chunk_sent_mock.called assert on_chunk_sent_mock.call_args == mock.call(chunk) async def test_write_to_closing_transport( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) await msg.write(b"Before closing") transport.is_closing.return_value = True # type: ignore[attr-defined] with pytest.raises(ClientConnectionResetError): await msg.write(b"After closing") async def test_write_to_closed_transport( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test that writing to a closed transport raises ClientConnectionResetError. The StreamWriter checks to see if protocol.transport is None before writing to the transport. If it is None, it raises ConnectionResetError. 
""" msg = http.StreamWriter(protocol, loop) await msg.write(b"Before transport close") protocol.transport = None with pytest.raises( ClientConnectionResetError, match="Cannot write to closing transport" ): await msg.write(b"After transport closed") async def test_drain( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) await msg.drain() assert protocol._drain_helper.called # type: ignore[attr-defined] async def test_drain_no_transport( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg._protocol.transport = None await msg.drain() assert not protocol._drain_helper.called # type: ignore[attr-defined] async def test_write_headers_prevents_injection( protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) status_line = "HTTP/1.1 200 OK" wrong_headers = CIMultiDict({"Set-Cookie: abc=123\r\nContent-Length": "256"}) with pytest.raises(ValueError): await msg.write_headers(status_line, wrong_headers) wrong_headers = CIMultiDict({"Content-Length": "256\r\nSet-Cookie: abc=123"}) with pytest.raises(ValueError): await msg.write_headers(status_line, wrong_headers) async def test_set_eof_after_write_headers( protocol: BaseProtocol, transport: mock.Mock, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) status_line = "HTTP/1.1 200 OK" good_headers = CIMultiDict({"Set-Cookie": "abc=123"}) # Write headers - should be buffered await msg.write_headers(status_line, good_headers) assert not transport.write.called # Headers are buffered # set_eof should send the buffered headers msg.set_eof() assert transport.write.called # Subsequent write_eof should do nothing transport.write.reset_mock() await msg.write_eof() assert not transport.write.called async def 
test_write_headers_does_not_write_immediately( protocol: BaseProtocol, transport: mock.Mock, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) status_line = "HTTP/1.1 200 OK" headers = CIMultiDict({"Content-Type": "text/plain"}) # write_headers should buffer, not write immediately await msg.write_headers(status_line, headers) assert not transport.write.called assert not transport.writelines.called # Headers should be sent when set_eof is called msg.set_eof() assert transport.write.called async def test_write_headers_with_compression_coalescing( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") headers = CIMultiDict({"Content-Encoding": "deflate", "Host": "example.com"}) # Write headers - should be buffered await msg.write_headers("POST /data HTTP/1.1", headers) assert len(buf) == 0 # Write compressed data via write_eof - should coalesce await msg.write_eof(b"Hello World") # Verify headers are present assert b"POST /data HTTP/1.1\r\n" in buf assert b"Content-Encoding: deflate\r\n" in buf # Verify compressed data is present # The data should contain headers + compressed payload assert len(buf) > 50 # Should have headers + some compressed data @pytest.mark.parametrize( "char", [ "\n", "\r", ], ) def test_serialize_headers_raises_on_new_line_or_carriage_return(char: str) -> None: """Verify serialize_headers raises on cr or nl in the headers.""" status_line = "HTTP/1.1 200 OK" headers = CIMultiDict( { hdrs.CONTENT_TYPE: f"text/plain{char}", } ) with pytest.raises( ValueError, match="detected in headers", ): _serialize_headers(status_line, headers) def test_serialize_headers_raises_on_null_byte() -> None: status_line = "HTTP/1.1 200 OK" headers = CIMultiDict( { hdrs.CONTENT_TYPE: "text/plain\x00", } ) with pytest.raises( ValueError, match="null byte detected in headers", ): 
_serialize_headers(status_line, headers) async def test_write_compressed_data_with_headers_coalescing( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test that headers are coalesced with compressed data in write() method.""" msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") headers = CIMultiDict({"Content-Encoding": "deflate", "Host": "example.com"}) # Write headers - should be buffered await msg.write_headers("POST /data HTTP/1.1", headers) assert len(buf) == 0 # Write compressed data - should coalesce with headers await msg.write(b"Hello World") # Headers and compressed data should be written together assert b"POST /data HTTP/1.1\r\n" in buf assert b"Content-Encoding: deflate\r\n" in buf assert len(buf) > 50 # Headers + compressed data async def test_write_compressed_chunked_with_headers_coalescing( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test headers coalescing with compressed chunked data.""" msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() headers = CIMultiDict( {"Content-Encoding": "deflate", "Transfer-Encoding": "chunked"} ) # Write headers - should be buffered await msg.write_headers("POST /data HTTP/1.1", headers) assert len(buf) == 0 # Write compressed chunked data - should coalesce await msg.write(b"Hello World") # Check headers are present assert b"POST /data HTTP/1.1\r\n" in buf assert b"Transfer-Encoding: chunked\r\n" in buf # Should have chunk size marker for compressed data output = buf.decode("latin-1", errors="ignore") assert "\r\n" in output # Should have chunk markers async def test_write_multiple_compressed_chunks_after_headers_sent( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test multiple compressed writes after headers are already sent.""" msg = 
http.StreamWriter(protocol, loop) msg.enable_compression("deflate") headers = CIMultiDict({"Content-Encoding": "deflate"}) # Write headers and send them immediately by writing first chunk await msg.write_headers("POST /data HTTP/1.1", headers) assert len(buf) == 0 # Headers buffered # Write first chunk - this will send headers + compressed data await msg.write(b"First chunk of data that should compress") len_after_first = len(buf) assert len_after_first > 0 # Headers + first chunk written # Write second chunk and force flush via EOF await msg.write(b"Second chunk of data that should also compress well") await msg.write_eof() # After EOF, all compressed data should be flushed final_len = len(buf) assert final_len > len_after_first async def test_write_eof_empty_compressed_with_buffered_headers( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test write_eof with no data but compression enabled and buffered headers.""" msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") headers = CIMultiDict({"Content-Encoding": "deflate"}) # Write headers - should be buffered await msg.write_headers("GET /data HTTP/1.1", headers) assert len(buf) == 0 # Write EOF with no data - should still coalesce headers with compression flush await msg.write_eof() # Headers should be present assert b"GET /data HTTP/1.1\r\n" in buf assert b"Content-Encoding: deflate\r\n" in buf # Should have compression flush data assert len(buf) > 40 async def test_write_compressed_gzip_with_headers_coalescing( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test gzip compression with header coalescing.""" msg = http.StreamWriter(protocol, loop) msg.enable_compression("gzip") headers = CIMultiDict({"Content-Encoding": "gzip"}) # Write headers - should be buffered await msg.write_headers("POST /data HTTP/1.1", headers) assert len(buf) == 0 # Write gzip 
compressed data via write_eof await msg.write_eof(b"Test gzip compression") # Verify coalescing happened assert b"POST /data HTTP/1.1\r\n" in buf assert b"Content-Encoding: gzip\r\n" in buf # Gzip typically produces more overhead than deflate assert len(buf) > 60 async def test_compression_with_content_length_constraint( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test compression respects content length constraints.""" msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.length = 5 # Set small content length headers = CIMultiDict({"Content-Length": "5"}) await msg.write_headers("POST /data HTTP/1.1", headers) # Write some initial data to trigger headers to be sent await msg.write(b"12345") # This matches our content length of 5 headers_and_first_chunk_len = len(buf) # Try to write more data than content length allows await msg.write(b"This is a longer message") # The second write should not add any data since content length is exhausted # After writing 5 bytes, length becomes 0, so additional writes are ignored assert len(buf) == headers_and_first_chunk_len # No additional data written async def test_write_compressed_zero_length_chunk( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test writing empty chunk with compression.""" msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") await msg.write_headers("POST /data HTTP/1.1", CIMultiDict()) # Force headers to be sent by writing something await msg.write(b"x") # Write something to trigger header send buf.clear() # Write empty chunk - compression may still produce output await msg.write(b"") # With compression, even empty input might produce small output # due to compression state, but it should be minimal assert len(buf) < 10 # Should be very small if anything async def test_chunked_compressed_eof_coalescing( buf: bytearray, 
protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test chunked compressed data with EOF marker coalescing.""" msg = http.StreamWriter(protocol, loop) msg.enable_compression("deflate") msg.enable_chunking() headers = CIMultiDict( {"Content-Encoding": "deflate", "Transfer-Encoding": "chunked"} ) # Buffer headers await msg.write_headers("POST /data HTTP/1.1", headers) assert len(buf) == 0 # Write compressed chunked data with EOF await msg.write_eof(b"Final compressed chunk") # Should have headers assert b"POST /data HTTP/1.1\r\n" in buf # Should end with chunked EOF marker assert buf.endswith(b"0\r\n\r\n") # Should have chunk size in hex before the compressed data output = buf # Verify we have chunk markers - look for \r\n followed by hex digits # The chunk size should be between the headers and the compressed data assert b"\r\n\r\n" in output # End of headers # After headers, we should have a hex chunk size headers_end = output.find(b"\r\n\r\n") + 4 chunk_data = output[headers_end:] # Should start with hex digits followed by \r\n assert ( chunk_data[:10] .strip() .decode("ascii", errors="ignore") .replace("\r\n", "") .isalnum() ) async def test_compression_different_strategies( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test compression with different strategies.""" # Test with best speed strategy (default) msg1 = http.StreamWriter(protocol, loop) msg1.enable_compression("deflate") # Default strategy await msg1.write_headers("POST /fast HTTP/1.1", CIMultiDict()) await msg1.write_eof(b"Test data for compression test data for compression") buf1_len = len(buf) # Both should produce output assert buf1_len > 0 # Headers should be present assert b"POST /fast HTTP/1.1\r\n" in buf # Since we can't easily test different compression strategies # (the compressor initialization might not support strategy parameter), # we just verify that compression works 
async def test_chunked_headers_single_write_with_set_eof(
    buf: bytearray,
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Test that set_eof combines headers and chunked EOF in single write.

    Headers are buffered by ``write_headers()``; the single transport write
    issued by ``set_eof()`` must carry both the headers and the terminating
    zero-length chunk.
    """
    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()

    # Write headers - should be buffered
    headers = CIMultiDict({"Transfer-Encoding": "chunked", "Host": "example.com"})
    await msg.write_headers("GET /test HTTP/1.1", headers)
    assert len(buf) == 0  # Headers not sent yet
    assert not transport.writelines.called  # type: ignore[attr-defined]  # No writelines calls yet

    # Call set_eof - should send headers + chunked EOF in single write call
    msg.set_eof()

    # Should have exactly one write call (since payload is small, writelines falls back to write)
    assert transport.write.call_count == 1  # type: ignore[attr-defined]
    assert transport.writelines.call_count == 0  # type: ignore[attr-defined]  # Not called for small payloads

    # The write call should have the combined headers and chunked EOF marker
    write_data = transport.write.call_args[0][0]  # type: ignore[attr-defined]
    assert write_data.startswith(b"GET /test HTTP/1.1\r\n")
    assert b"Transfer-Encoding: chunked\r\n" in write_data
    assert write_data.endswith(b"\r\n\r\n0\r\n\r\n")  # Headers end + chunked EOF

    # Verify final output
    assert b"GET /test HTTP/1.1\r\n" in buf
    assert b"Transfer-Encoding: chunked\r\n" in buf
    assert buf.endswith(b"0\r\n\r\n")


async def test_send_headers_forces_header_write(
    buf: bytearray,
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Test that send_headers() forces writing buffered headers.

    Also checks that a later ``write()`` emits only the body.
    """
    msg = http.StreamWriter(protocol, loop)

    headers = CIMultiDict({"Content-Length": "10", "Host": "example.com"})

    # Write headers (should be buffered)
    await msg.write_headers("GET /test HTTP/1.1", headers)
    assert len(buf) == 0  # Headers buffered

    # Force send headers
    msg.send_headers()

    # Headers should now be written
    assert b"GET /test HTTP/1.1\r\n" in buf
    assert b"Content-Length: 10\r\n" in buf
    assert b"Host: example.com\r\n" in buf

    # Writing body should not resend headers
    buf.clear()
    await msg.write(b"0123456789")
    assert b"GET /test" not in buf  # Headers not repeated
    assert buf == b"0123456789"  # Just the body


async def test_send_headers_idempotent(
    buf: bytearray,
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Test that send_headers() is idempotent and safe to call multiple times."""
    msg = http.StreamWriter(protocol, loop)

    headers = CIMultiDict({"Content-Length": "5", "Host": "example.com"})

    # Write headers (should be buffered)
    await msg.write_headers("GET /test HTTP/1.1", headers)
    assert len(buf) == 0  # Headers buffered

    # Force send headers
    msg.send_headers()
    # Immutable snapshot: buf keeps growing, so a copy is needed to compare against
    headers_output = bytes(buf)

    # Call send_headers again - should be no-op
    msg.send_headers()
    assert buf == headers_output  # No additional output

    # Call send_headers after headers already sent - should be no-op
    await msg.write(b"hello")
    msg.send_headers()
    assert buf[len(headers_output) :] == b"hello"  # Only body added


async def test_send_headers_no_buffered_headers(
    buf: bytearray,
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Test that send_headers() is safe when no headers are buffered."""
    msg = http.StreamWriter(protocol, loop)

    # Call send_headers without writing headers first
    msg.send_headers()  # Should not crash
    assert len(buf) == 0  # No output


async def test_write_drain_condition_with_small_buffer(
    buf: bytearray,
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Test that drain is not called when buffer_size <= LIMIT."""
    msg = http.StreamWriter(protocol, loop)

    # Write headers first
    await msg.write_headers("GET /test HTTP/1.1", CIMultiDict())
    msg.send_headers()  # Send headers to start with clean state

    # Reset buffer size manually since send_headers doesn't do it
    msg.buffer_size = 0

    # Reset drain helper mock
    protocol._drain_helper.reset_mock()  # type: ignore[attr-defined]

    # Write small amount of data with drain=True but buffer under limit
    small_data = b"x" * 100  # Much less than LIMIT (2**16)
    await msg.write(small_data, drain=True)

    # Drain should NOT be called because buffer_size <= LIMIT
    assert not protocol._drain_helper.called  # type: ignore[attr-defined]
    assert msg.buffer_size == 100
    assert small_data in buf


async def test_write_drain_condition_with_large_buffer(
    buf: bytearray,
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Test that drain is called only when drain=True AND buffer_size > LIMIT."""
    msg = http.StreamWriter(protocol, loop)

    # Write headers first
    await msg.write_headers("GET /test HTTP/1.1", CIMultiDict())
    msg.send_headers()  # Send headers to start with clean state

    # Reset buffer size manually since send_headers doesn't do it
    msg.buffer_size = 0

    # Reset drain helper mock
    protocol._drain_helper.reset_mock()  # type: ignore[attr-defined]

    # Write large amount of data with drain=True
    large_data = b"x" * (2**16 + 1)  # Just over LIMIT
    await msg.write(large_data, drain=True)

    # Drain should be called because drain=True AND buffer_size > LIMIT
    assert protocol._drain_helper.called  # type: ignore[attr-defined]
    assert msg.buffer_size == 0  # Buffer reset after drain
    assert large_data in buf


async def test_write_no_drain_with_large_buffer(
    buf: bytearray,
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Test that drain is not called when drain=False even with large buffer."""
    msg = http.StreamWriter(protocol, loop)

    # Write headers first
    await msg.write_headers("GET /test HTTP/1.1", CIMultiDict())
    msg.send_headers()  # Send headers to start with clean state

    # Reset buffer size manually since send_headers doesn't do it
    msg.buffer_size = 0

    # Reset drain helper mock
    protocol._drain_helper.reset_mock()  # type:
ignore[attr-defined] # Write large amount of data with drain=False large_data = b"x" * (2**16 + 1) # Just over LIMIT await msg.write(large_data, drain=False) # Drain should NOT be called because drain=False assert not protocol._drain_helper.called # type: ignore[attr-defined] assert msg.buffer_size == (2**16 + 1) # Buffer not reset assert large_data in buf async def test_set_eof_idempotent( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test that set_eof() is idempotent and can be called multiple times safely.""" msg = http.StreamWriter(protocol, loop) # Test 1: Multiple set_eof calls with buffered headers headers = CIMultiDict({"Content-Length": "0"}) await msg.write_headers("GET /test HTTP/1.1", headers) # First set_eof should send headers msg.set_eof() first_output = buf assert b"GET /test HTTP/1.1\r\n" in first_output assert b"Content-Length: 0\r\n" in first_output # Second set_eof should be no-op msg.set_eof() assert bytes(buf) == first_output # No additional output # Third set_eof should also be no-op msg.set_eof() assert bytes(buf) == first_output # Still no additional output # Test 2: set_eof with chunked encoding buf.clear() msg2 = http.StreamWriter(protocol, loop) msg2.enable_chunking() headers2 = CIMultiDict({"Transfer-Encoding": "chunked"}) await msg2.write_headers("POST /data HTTP/1.1", headers2) # First set_eof should send headers + chunked EOF msg2.set_eof() chunked_output = buf assert b"POST /data HTTP/1.1\r\n" in buf assert b"Transfer-Encoding: chunked\r\n" in buf assert b"0\r\n\r\n" in buf # Chunked EOF marker # Second set_eof should be no-op msg2.set_eof() assert buf == chunked_output # No additional output # Test 3: set_eof after headers already sent buf.clear() msg3 = http.StreamWriter(protocol, loop) headers3 = CIMultiDict({"Content-Length": "5"}) await msg3.write_headers("PUT /update HTTP/1.1", headers3) # Send headers by writing some data await msg3.write(b"hello") 
headers_and_body = buf # set_eof after headers sent should be no-op msg3.set_eof() assert buf == headers_and_body # No additional output # Another set_eof should still be no-op msg3.set_eof() assert buf == headers_and_body # Still no additional output async def test_non_chunked_write_empty_body( buf: bytearray, protocol: BaseProtocol, transport: mock.Mock, loop: asyncio.AbstractEventLoop, ) -> None: """Test non-chunked response with empty body.""" msg = http.StreamWriter(protocol, loop) # Non-chunked response with Content-Length: 0 headers = CIMultiDict({"Content-Length": "0"}) await msg.write_headers("GET /empty HTTP/1.1", headers) # Write empty body await msg.write(b"") # Check the output assert b"GET /empty HTTP/1.1\r\n" in buf assert b"Content-Length: 0\r\n" in buf async def test_chunked_headers_sent_with_empty_chunk_not_eof( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test chunked encoding where headers are sent without data and not EOF.""" msg = http.StreamWriter(protocol, loop) msg.enable_chunking() headers = CIMultiDict({"Transfer-Encoding": "chunked"}) await msg.write_headers("POST /upload HTTP/1.1", headers) # This should trigger the else case in _send_headers_with_payload # by having no chunk data and is_eof=False await msg.write(b"") # Headers should be sent alone assert b"POST /upload HTTP/1.1\r\n" in buf assert b"Transfer-Encoding: chunked\r\n" in buf # Should not have any chunk markers yet assert b"0\r\n" not in buf async def test_chunked_set_eof_after_headers_sent( buf: bytearray, protocol: BaseProtocol, transport: asyncio.Transport, loop: asyncio.AbstractEventLoop, ) -> None: """Test chunked encoding where set_eof is called after headers already sent.""" msg = http.StreamWriter(protocol, loop) msg.enable_chunking() headers = CIMultiDict({"Transfer-Encoding": "chunked"}) await msg.write_headers("POST /data HTTP/1.1", headers) # Send headers by writing some data await 
msg.write(b"test data")
    buf.clear()  # Clear buffer to check only what set_eof writes

    # This should trigger writing chunked EOF when headers already sent
    msg.set_eof()

    # Should only have the chunked EOF marker
    assert buf == b"0\r\n\r\n"


@pytest.mark.usefixtures("enable_writelines")
@pytest.mark.usefixtures("force_writelines_small_payloads")
async def test_write_eof_chunked_with_data_using_writelines(
    buf: bytearray,
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Test write_eof with chunked data that uses writelines.

    Headers are flushed by an initial ``write()``; the final chunk passed to
    ``write_eof()`` must then go through ``transport.writelines`` and carry
    the chunk size, the data and the chunked EOF marker.
    """
    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()

    headers = CIMultiDict({"Transfer-Encoding": "chunked"})
    await msg.write_headers("POST /data HTTP/1.1", headers)

    # Send headers first
    await msg.write(b"initial")
    # Forget earlier calls so only the write_eof() call is inspected below
    transport.writelines.reset_mock()  # type: ignore[attr-defined]

    # This should trigger writelines for final chunk with EOF
    await msg.write_eof(b"final chunk data")

    # Should have used writelines
    assert transport.writelines.called  # type: ignore[attr-defined]

    # Get the data from writelines call
    writelines_data = transport.writelines.call_args[0][0]  # type: ignore[attr-defined]
    combined = b"".join(writelines_data)

    # Should have chunk size, data, and EOF marker
    assert b"10\r\n" in combined  # hex for 16 (length of "final chunk data")
    assert b"final chunk data" in combined
    assert b"0\r\n\r\n" in combined


async def test_send_headers_with_payload_chunked_eof_no_data(
    buf: bytearray,
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Test _send_headers_with_payload with chunked, is_eof=True but no chunk data."""
    msg = http.StreamWriter(protocol, loop)
    msg.enable_chunking()

    headers = CIMultiDict({"Transfer-Encoding": "chunked"})
    await msg.write_headers("GET /test HTTP/1.1", headers)

    # This triggers the elif is_eof branch in _send_headers_with_payload
    # by calling write_eof with empty chunk
    await msg.write_eof(b"")

    # Should have
headers and chunked EOF marker together assert b"GET /test HTTP/1.1\r\n" in buf assert b"Transfer-Encoding: chunked\r\n" in buf assert buf.endswith(b"0\r\n\r\n") ================================================ FILE: tests/test_imports.py ================================================ import os import platform import sys from pathlib import Path import pytest def test___all__(pytester: pytest.Pytester) -> None: """See https://github.com/aio-libs/aiohttp/issues/6197""" pytester.makepyfile(test_a=""" from aiohttp import * assert 'GunicornWebWorker' in globals() """) result = pytester.runpytest("-vv") result.assert_outcomes(passed=0, errors=0) def test_web___all__(pytester: pytest.Pytester) -> None: pytester.makepyfile(test_b=""" from aiohttp.web import * """) result = pytester.runpytest("-vv") result.assert_outcomes(passed=0, errors=0) @pytest.mark.internal @pytest.mark.dev_mode @pytest.mark.skipif( not sys.platform.startswith("linux") or platform.python_implementation() == "PyPy", reason="Timing is more reliable on Linux", ) def test_import_time(pytester: pytest.Pytester) -> None: """Check that importing aiohttp doesn't take too long. Obviously, the time may vary on different machines and may need to be adjusted from time to time, but this should provide an early warning if something is added that significantly increases import time. Runs 3 times and keeps the minimum time to reduce flakiness. 
""" IMPORT_TIME_THRESHOLD_MS = 300 if sys.version_info >= (3, 12) else 200 root = Path(__file__).parent.parent old_path = os.environ.get("PYTHONPATH") os.environ["PYTHONPATH"] = os.pathsep.join([str(root)] + sys.path) best_time_ms = 1000 cmd = "import timeit; print(int(timeit.timeit('import aiohttp', number=1) * 1000))" try: for _ in range(3): r = pytester.run(sys.executable, "-We", "-c", cmd) assert not r.stderr.str(), r.stderr.str() best_time_ms = min(best_time_ms, int(r.stdout.str())) finally: if old_path is None: os.environ.pop("PYTHONPATH") else: # pragma: no cover os.environ["PYTHONPATH"] = old_path assert best_time_ms < IMPORT_TIME_THRESHOLD_MS ================================================ FILE: tests/test_leaks.py ================================================ import pathlib import platform import subprocess import sys import pytest IS_PYPY = platform.python_implementation() == "PyPy" @pytest.mark.skipif(IS_PYPY, reason="gc.DEBUG_LEAK not available on PyPy") @pytest.mark.parametrize( ("script", "message"), [ ( # Test that ClientResponse is collected after server disconnects. # https://github.com/aio-libs/aiohttp/issues/10535 "check_for_client_response_leak.py", "ClientResponse leaked", ), ( # Test that Request object is collected when the handler raises. 
# https://github.com/aio-libs/aiohttp/issues/10548 "check_for_request_leak.py", "Request leaked", ), ], ) def test_leak(script: str, message: str) -> None: """Run isolated leak test script and check for leaks.""" leak_test_script = pathlib.Path(__file__).parent.joinpath("isolated", script) with subprocess.Popen( [sys.executable, "-u", str(leak_test_script)], stdout=subprocess.PIPE, ) as proc: assert proc.wait() == 0, message ================================================ FILE: tests/test_loop.py ================================================ import asyncio import platform import threading import pytest from aiohttp import web from aiohttp.test_utils import AioHTTPTestCase, loop_context @pytest.mark.skipif( platform.system() == "Windows", reason="the test is not valid for Windows" ) async def test_subprocess_co(loop: asyncio.AbstractEventLoop) -> None: proc = await asyncio.create_subprocess_shell( "exit 0", stdin=asyncio.subprocess.DEVNULL, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.DEVNULL, ) await proc.wait() class TestCase(AioHTTPTestCase): on_startup_called: bool async def get_application(self) -> web.Application: app = web.Application() app.on_startup.append(self.on_startup_hook) return app async def on_startup_hook(self, app: web.Application) -> None: self.on_startup_called = True async def test_on_startup_hook(self) -> None: self.assertTrue(self.on_startup_called) def test_default_loop(loop: asyncio.AbstractEventLoop) -> None: assert asyncio.get_event_loop() is loop def test_setup_loop_non_main_thread() -> None: child_exc = None def target() -> None: try: with loop_context() as loop: assert asyncio.get_event_loop() is loop loop.run_until_complete(test_subprocess_co(loop)) except Exception as exc: # pragma: no cover nonlocal child_exc child_exc = exc # Ensures setup_test_loop can be called by pytest-xdist in non-main thread. 
t = threading.Thread(target=target)
    t.start()
    t.join()
    assert child_exc is None


================================================
FILE: tests/test_multipart.py
================================================
import asyncio
import io
import json
import pathlib
import sys
from types import TracebackType
from unittest import mock

import pytest
from multidict import CIMultiDict, CIMultiDictProxy

import aiohttp
from aiohttp import payload
from aiohttp.abc import AbstractStreamWriter
from aiohttp.compression_utils import ZLibBackend
from aiohttp.hdrs import (
    CONTENT_DISPOSITION,
    CONTENT_ENCODING,
    CONTENT_TRANSFER_ENCODING,
    CONTENT_TYPE,
)
from aiohttp.helpers import parse_mimetype
from aiohttp.multipart import (
    BodyPartReader,
    BodyPartReaderPayload,
    MultipartReader,
    MultipartResponseWrapper,
)
from aiohttp.streams import StreamReader

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing import TypeVar

    Self = TypeVar("Self", bound="Stream")

# Boundary marker handed to the body part readers in the tests below
BOUNDARY: bytes = b"--:"


@pytest.fixture
def buf() -> bytearray:
    """Return an empty buffer that collects what the ``stream`` fixture writes."""
    return bytearray()


@pytest.fixture
def stream(buf: bytearray) -> AbstractStreamWriter:
    """Return an autospecced stream writer whose ``write()`` appends to ``buf``."""
    writer = mock.create_autospec(AbstractStreamWriter, instance=True, spec_set=True)

    async def write(chunk: bytes) -> None:
        buf.extend(chunk)

    writer.write.side_effect = write
    return writer  # type: ignore[no-any-return]


@pytest.fixture
def buf2() -> bytearray:
    """Return a second, independent empty buffer (paired with ``stream2``)."""
    return bytearray()


@pytest.fixture
def stream2(buf2: bytearray) -> mock.Mock:
    """Return a plain ``Mock`` writer whose ``write()`` appends to ``buf2``."""
    writer = mock.Mock()

    async def write(chunk: bytes) -> None:
        buf2.extend(chunk)

    writer.write.side_effect = write
    return writer


@pytest.fixture
def writer() -> aiohttp.MultipartWriter:
    """Return a ``MultipartWriter`` using ``":"`` as its boundary."""
    return aiohttp.MultipartWriter(boundary=":")


class Stream(StreamReader):
    """In-memory stand-in for ``StreamReader`` backed by ``io.BytesIO``.

    ``StreamReader.__init__`` is not invoked; the methods overridden below
    operate on the ``BytesIO`` object only. Usable as a context manager that
    closes the buffer on exit.
    """

    def __init__(self, content: bytes) -> None:
        self.content = io.BytesIO(content)

    async def read(self, size: int | None = None) -> bytes:
        """Read up to ``size`` bytes (everything remaining when ``None``)."""
        return self.content.read(size)

    def at_eof(self) -> bool:
        """Return True once the read position has reached the end of the data."""
        return self.content.tell() == len(self.content.getbuffer())

    async def readline(self, *, max_line_length: int | None = None) -> bytes:
        """Read one line; ``max_line_length`` is accepted but not used here."""
        return self.content.readline()

    def unread_data(self, data: bytes) -> None:
        """Push ``data`` back in front of whatever has not been read yet."""
        # Rebuilds the buffer from the pushed-back data plus the unread remainder
        self.content = io.BytesIO(data + self.content.read())

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        self.content.close()


class Response:
    """Minimal response holder exposing just ``headers`` and ``content``."""

    def __init__(self, headers: CIMultiDictProxy[str], content: Stream) -> None:
        self.headers = headers
        self.content = content


class StreamWithShortenRead(Stream):
    """``Stream`` variant whose first sized ``read()`` is asked for half the size."""

    def __init__(self, content: bytes) -> None:
        self._first = True
        super().__init__(content)

    async def read(self, size: int | None = None) -> bytes:
        # Only the first call that passes an explicit size is shortened
        if size is not None and self._first:
            self._first = False
            size = size // 2
        return await super().read(size)


class TestMultipartResponseWrapper:
    def test_at_eof(self) -> None:
        m_resp = mock.create_autospec(aiohttp.ClientResponse, spec_set=True)
        m_stream = mock.create_autospec(MultipartReader, spec_set=True)
        wrapper = MultipartResponseWrapper(m_resp, m_stream)
        wrapper.at_eof()
        assert m_resp.content.at_eof.called

    async def test_next(self) -> None:
        m_resp = mock.create_autospec(aiohttp.ClientResponse, spec_set=True)
        m_stream = mock.create_autospec(MultipartReader, spec_set=True)
        wrapper = MultipartResponseWrapper(m_resp, m_stream)
        m_stream.next.return_value = b""
        m_stream.at_eof.return_value = False
        await wrapper.next()
        assert m_stream.next.called

    async def test_release(self) -> None:
        m_resp = mock.create_autospec(aiohttp.ClientResponse, spec_set=True)
        m_stream = mock.create_autospec(MultipartReader, spec_set=True)
        wrapper = MultipartResponseWrapper(m_resp, m_stream)
        await wrapper.release()
        assert m_resp.release.called

    async def test_release_when_stream_at_eof(self) -> None:
        m_resp = mock.create_autospec(aiohttp.ClientResponse, spec_set=True)
        m_stream = mock.create_autospec(MultipartReader, spec_set=True)
        wrapper = MultipartResponseWrapper(m_resp, m_stream)
        m_stream.next.return_value = b""
m_stream.at_eof.return_value = True await wrapper.next() assert m_stream.next.called assert m_resp.release.called class TestPartReader: async def test_next(self) -> None: with Stream(b"Hello, world!\r\n--:") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.next() assert b"Hello, world!" == result assert obj.at_eof() async def test_next_next(self) -> None: with Stream(b"Hello, world!\r\n--:") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.next() assert b"Hello, world!" == result assert obj.at_eof() result = await obj.next() assert result is None async def test_read(self) -> None: with Stream(b"Hello, world!\r\n--:") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.read() assert b"Hello, world!" == result assert obj.at_eof() async def test_read_chunk_at_eof(self) -> None: with Stream(b"--:") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) obj._at_eof = True result = await obj.read_chunk() assert b"" == result async def test_read_chunk_without_content_length(self) -> None: with Stream(b"Hello, world!\r\n--:") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) c1 = await obj.read_chunk(8) c2 = await obj.read_chunk(8) c3 = await obj.read_chunk(8) assert c1 + c2 == b"Hello, world!" 
assert c3 == b"" async def test_read_incomplete_chunk(self) -> None: with Stream(b"") as stream: def prepare(data: bytes) -> bytes: return data with mock.patch.object( stream, "read", side_effect=[ prepare(b"Hello, "), prepare(b"World"), prepare(b"!\r\n--:"), prepare(b""), ], ): d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) c1 = await obj.read_chunk(8) assert c1 == b"Hello, " c2 = await obj.read_chunk(8) assert c2 == b"World" c3 = await obj.read_chunk(8) assert c3 == b"!" async def test_read_all_at_once(self) -> None: with Stream(b"Hello, World!\r\n--:--\r\n") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.read_chunk() assert b"Hello, World!" == result result = await obj.read_chunk() assert b"" == result assert obj.at_eof() async def test_read_incomplete_body_chunked(self) -> None: with Stream(b"Hello, World!\r\n-") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = b"" with pytest.raises(ValueError): for _ in range(4): # pragma: no branch result += await obj.read_chunk(7) assert b"Hello, World!\r\n-" == result async def test_read_with_content_length_malformed_crlf(self) -> None: # Content-Length is correct but data after content is not \r\n content = b"Hello" h = CIMultiDictProxy(CIMultiDict({"CONTENT-LENGTH": str(len(content))})) # Malformed: "XX" instead of "\r\n" after content with Stream(content + b"XX--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) with pytest.raises(ValueError, match="malformed"): await obj.read() async def test_read_boundary_with_incomplete_chunk(self) -> None: with Stream(b"") as stream: def prepare(data: bytes) -> bytes: return data with mock.patch.object( stream, "read", side_effect=[ prepare(b"Hello, World"), prepare(b"!\r\n"), prepare(b"--:"), prepare(b""), ], ): d = CIMultiDictProxy[str](CIMultiDict()) obj = 
aiohttp.BodyPartReader(BOUNDARY, d, stream) c1 = await obj.read_chunk(12) assert c1 == b"Hello, World" c2 = await obj.read_chunk(8) assert c2 == b"!" c3 = await obj.read_chunk(8) assert c3 == b"" async def test_multi_read_chunk(self) -> None: with Stream(b"Hello,\r\n--:\r\n\r\nworld!\r\n--:--") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.read_chunk(8) assert b"Hello," == result result = await obj.read_chunk(8) assert b"" == result assert obj.at_eof() async def test_read_chunk_properly_counts_read_bytes(self) -> None: expected = b"." * 10 size = len(expected) h = CIMultiDictProxy(CIMultiDict({"CONTENT-LENGTH": str(size)})) with StreamWithShortenRead(expected + b"\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = bytearray() while True: chunk = await obj.read_chunk() if not chunk: break result.extend(chunk) assert size == len(result) assert b"." * size == result assert obj.at_eof() async def test_read_does_not_read_boundary(self) -> None: with Stream(b"Hello, world!\r\n--:") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.read() assert b"Hello, world!" == result assert b"--:" == (await stream.read()) async def test_multiread(self) -> None: with Stream(b"Hello,\r\n--:\r\n\r\nworld!\r\n--:--") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.read() assert b"Hello," == result result = await obj.read() assert b"" == result assert obj.at_eof() async def test_read_multiline(self) -> None: with Stream(b"Hello\n,\r\nworld!\r\n--:--") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.read() assert b"Hello\n,\r\nworld!" 
== result result = await obj.read() assert b"" == result assert obj.at_eof() async def test_read_respects_content_length(self) -> None: h = CIMultiDictProxy(CIMultiDict({"CONTENT-LENGTH": "100500"})) with Stream(b"." * 100500 + b"\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.read() assert b"." * 100500 == result assert obj.at_eof() async def test_read_with_content_encoding_gzip(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "gzip"})) with Stream( b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03\x0b\xc9\xccMU" b"(\xc9W\x08J\xcdI\xacP\x04\x00$\xfb\x9eV\x0e\x00\x00\x00" b"\r\n--:--" ) as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.read(decode=True) assert b"Time to Relax!" == result async def test_read_with_content_encoding_deflate(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "deflate"})) with Stream(b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.read(decode=True) assert b"Time to Relax!" 
== result async def test_read_with_content_encoding_identity(self) -> None: thing = ( b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03\x0b\xc9\xccMU" b"(\xc9W\x08J\xcdI\xacP\x04\x00$\xfb\x9eV\x0e\x00\x00\x00" b"\r\n" ) h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "identity"})) with Stream(thing + b"--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.read(decode=True) assert thing[:-2] == result async def test_read_with_content_encoding_unknown(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "snappy"})) with Stream(b"\x0e4Time to Relax!\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) with pytest.raises(RuntimeError): await obj.read(decode=True) async def test_read_with_content_transfer_encoding_base64(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TRANSFER_ENCODING: "base64"})) with Stream(b"VGltZSB0byBSZWxheCE=\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.read(decode=True) assert b"Time to Relax!" == result async def test_decode_with_content_transfer_encoding_base64(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TRANSFER_ENCODING: "base64"})) with Stream(b"VG\r\r\nltZSB0byBSZ\r\nWxheCE=\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = b"" while not obj.at_eof(): chunk = await obj.read_chunk(size=6) result += obj.decode(chunk) assert b"Time to Relax!" == result async def test_decode_iter_with_content_transfer_encoding_base64(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TRANSFER_ENCODING: "base64"})) with Stream(b"VG\r\r\nltZSB0byBSZ\r\nWxheCE=\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = b"" while not obj.at_eof(): chunk = await obj.read_chunk(size=6) async for decoded_chunk in obj.decode_iter(chunk): result += decoded_chunk assert b"Time to Relax!" 
== result async def test_decode_with_content_encoding_deflate(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "deflate"})) data = b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00" with Stream(data + b"\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) chunk = await obj.read_chunk(size=len(data)) result = obj.decode(chunk) assert b"Time to Relax!" == result async def test_decode_with_content_encoding_identity(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "identity"})) data = b"Time to Relax!" with Stream(data + b"\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) chunk = await obj.read_chunk(size=len(data)) result = obj.decode(chunk) assert data == result async def test_decode_with_content_encoding_unknown(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_ENCODING: "snappy"})) data = b"Time to Relax!" with Stream(data + b"\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) chunk = await obj.read_chunk(size=len(data)) with pytest.raises(RuntimeError, match="unknown content encoding"): obj.decode(chunk) async def test_read_with_content_transfer_encoding_quoted_printable(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_TRANSFER_ENCODING: "quoted-printable"}) ) with Stream( b"=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82, =D0=BC=D0=B8=D1=80!\r\n--:--" ) as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.read(decode=True) expected = ( b"\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82," b" \xd0\xbc\xd0\xb8\xd1\x80!" ) assert result == expected @pytest.mark.parametrize("encoding", ("binary", "8bit", "7bit")) async def test_read_with_content_transfer_encoding_binary( self, encoding: str ) -> None: data = ( b"\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82," b" \xd0\xbc\xd0\xb8\xd1\x80!" 
) h = CIMultiDictProxy(CIMultiDict({CONTENT_TRANSFER_ENCODING: encoding})) with Stream(data + b"\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.read(decode=True) assert data == result async def test_read_with_content_transfer_encoding_unknown(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TRANSFER_ENCODING: "unknown"})) with Stream(b"\x0e4Time to Relax!\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) with pytest.raises(RuntimeError): await obj.read(decode=True) async def test_read_text(self) -> None: with Stream(b"Hello, world!\r\n--:--") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.text() assert "Hello, world!" == result async def test_read_text_default_encoding(self) -> None: with Stream("Привет, Мир!\r\n--:--".encode()) as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.text() assert "Привет, Мир!" == result async def test_read_text_encoding(self) -> None: with Stream("Привет, Мир!\r\n--:--".encode("cp1251")) as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.text(encoding="cp1251") assert "Привет, Мир!" == result async def test_read_text_guess_encoding(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TYPE: "text/plain;charset=cp1251"})) with Stream("Привет, Мир!\r\n--:--".encode("cp1251")) as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.text() assert "Привет, Мир!" == result async def test_read_text_compressed(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_ENCODING: "deflate", CONTENT_TYPE: "text/plain"}) ) with Stream(b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.text() assert "Time to Relax!" 
== result async def test_read_text_while_closed(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TYPE: "text/plain"})) with Stream(b"") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) obj._at_eof = True result = await obj.text() assert "" == result async def test_read_json(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TYPE: "application/json"})) with Stream(b'{"test": "passed"}\r\n--:--') as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.json() assert {"test": "passed"} == result async def test_read_json_encoding(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TYPE: "application/json"})) with Stream('{"тест": "пассед"}\r\n--:--'.encode("cp1251")) as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.json(encoding="cp1251") assert {"тест": "пассед"} == result async def test_read_json_guess_encoding(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_TYPE: "application/json; charset=cp1251"}) ) with Stream('{"тест": "пассед"}\r\n--:--'.encode("cp1251")) as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.json() assert {"тест": "пассед"} == result async def test_read_json_compressed(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_ENCODING: "deflate", CONTENT_TYPE: "application/json"}) ) with Stream(b"\xabV*I-.Q\xb2RP*H,.NMQ\xaa\x05\x00\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.json() assert {"test": "passed"} == result async def test_read_json_while_closed(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TYPE: "application/json"})) with Stream(b"") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) obj._at_eof = True result = await obj.json() assert result is None async def test_read_form(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_TYPE: "application/x-www-form-urlencoded"}) ) with Stream(b"foo=bar&foo=baz&boo=\r\n--:--") as stream: obj = 
aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.form() assert [("foo", "bar"), ("foo", "baz"), ("boo", "")] == result async def test_read_form_invalid_utf8(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_TYPE: "application/x-www-form-urlencoded"}) ) with Stream(b"\xff\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) with pytest.raises( ValueError, match="data cannot be decoded with utf-8 encoding" ): await obj.form() async def test_read_form_encoding(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_TYPE: "application/x-www-form-urlencoded"}) ) with Stream("foo=bar&foo=baz&boo=\r\n--:--".encode("cp1251")) as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.form(encoding="cp1251") assert [("foo", "bar"), ("foo", "baz"), ("boo", "")] == result async def test_read_form_guess_encoding(self) -> None: h = CIMultiDictProxy( CIMultiDict( {CONTENT_TYPE: "application/x-www-form-urlencoded; charset=utf-8"} ) ) with Stream(b"foo=bar&foo=baz&boo=\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) result = await obj.form() assert [("foo", "bar"), ("foo", "baz"), ("boo", "")] == result async def test_read_form_while_closed(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_TYPE: "application/x-www-form-urlencoded"}) ) with Stream(b"") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) obj._at_eof = True result = await obj.form() assert not result async def test_readline(self) -> None: with Stream(b"Hello\n,\r\nworld!\r\n--:--") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) result = await obj.readline() assert b"Hello\n" == result result = await obj.readline() assert b",\r\n" == result result = await obj.readline() assert b"world!" 
== result result = await obj.readline() assert b"" == result assert obj.at_eof() async def test_release(self) -> None: with Stream(b"Hello,\r\n--:\r\n\r\nworld!\r\n--:--") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) await obj.release() assert obj.at_eof() assert b"--:\r\n\r\nworld!\r\n--:--" == stream.content.read() async def test_release_respects_content_length(self) -> None: h = CIMultiDictProxy(CIMultiDict({"CONTENT-LENGTH": "100500"})) with Stream(b"." * 100500 + b"\r\n--:--") as stream: obj = aiohttp.BodyPartReader(BOUNDARY, h, stream) await obj.release() assert obj.at_eof() async def test_release_release(self) -> None: with Stream(b"Hello,\r\n--:\r\n\r\nworld!\r\n--:--") as stream: d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) await obj.release() await obj.release() assert b"--:\r\n\r\nworld!\r\n--:--" == stream.content.read() async def test_filename(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_DISPOSITION: "attachment; filename=foo.html"}) ) part = aiohttp.BodyPartReader(BOUNDARY, h, mock.Mock()) assert "foo.html" == part.filename async def test_reading_long_part(self) -> None: size = 2 * 2**16 protocol = mock.Mock(_reading_paused=False) stream = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) stream.feed_data(b"0" * size + b"\r\n--:--") stream.feed_eof() d = CIMultiDictProxy[str](CIMultiDict()) obj = aiohttp.BodyPartReader(BOUNDARY, d, stream) data = await obj.read() assert len(data) == size class TestMultipartReader: def test_from_response(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_TYPE: 'multipart/related;boundary=":"'}) ) with Stream(b"--:\r\n\r\nhello\r\n--:--") as stream: resp = Response(h, stream) res = aiohttp.MultipartReader.from_response(resp) # type: ignore[arg-type] assert isinstance(res, MultipartResponseWrapper) assert isinstance(res.stream, aiohttp.MultipartReader) def test_bad_boundary(self) -> None: 
h = CIMultiDictProxy( CIMultiDict({CONTENT_TYPE: "multipart/related;boundary=" + "a" * 80}) ) with Stream(b"") as stream: resp = Response(h, stream) with pytest.raises(ValueError): aiohttp.MultipartReader.from_response(resp) # type: ignore[arg-type] def test_dispatch(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TYPE: "text/plain"})) with Stream(b"--:\r\n\r\necho\r\n--:--") as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) res = reader._get_part_reader(h) assert isinstance(res, reader.part_reader_cls) def test_dispatch_bodypart(self) -> None: h = CIMultiDictProxy(CIMultiDict({CONTENT_TYPE: "text/plain"})) with Stream(b"--:\r\n\r\necho\r\n--:--") as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) res = reader._get_part_reader(h) assert isinstance(res, reader.part_reader_cls) def test_dispatch_multipart(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_TYPE: "multipart/related;boundary=--:--"}) ) with Stream( b"----:--\r\n" b"\r\n" b"test\r\n" b"----:--\r\n" b"\r\n" b"passed\r\n" b"----:----\r\n" b"--:--" ) as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) res = reader._get_part_reader(h) assert isinstance(res, reader.__class__) def test_dispatch_custom_multipart_reader(self) -> None: class CustomReader(aiohttp.MultipartReader): pass h = CIMultiDictProxy( CIMultiDict({CONTENT_TYPE: "multipart/related;boundary=--:--"}) ) with Stream( b"----:--\r\n" b"\r\n" b"test\r\n" b"----:--\r\n" b"\r\n" b"passed\r\n" b"----:----\r\n" b"--:--" ) as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) reader.multipart_reader_cls = CustomReader res = reader._get_part_reader(h) assert isinstance(res, CustomReader) async def test_emit_next(self) -> None: h = CIMultiDictProxy( CIMultiDict({CONTENT_TYPE: 'multipart/related;boundary=":"'}) ) with 
Stream(b"--:\r\n\r\necho\r\n--:--") as stream: reader = aiohttp.MultipartReader(h, stream) res = await reader.next() assert isinstance(res, reader.part_reader_cls) async def test_invalid_boundary(self) -> None: with Stream(b"---:\r\n\r\necho\r\n---:--") as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) with pytest.raises(ValueError): await reader.next() async def test_read_boundary_across_chunks(self) -> None: class SplitBoundaryStream(StreamReader): def __init__(self) -> None: self.content = [ b"--foobar\r\n\r\n", b"Hello,\r\n-", b"-fo", b"ob", b"ar\r\n", b"\r\nwor", b"ld!", b"\r\n--f", b"oobar--", ] async def read(self, size: int | None = None) -> bytes: chunk = self.content.pop(0) assert size is not None and len(chunk) <= size return chunk def at_eof(self) -> bool: return not self.content async def readline(self, *, max_line_length: int | None = None) -> bytes: line = b"" while self.content and b"\n" not in line: line += self.content.pop(0) line, *extra = line.split(b"\n", maxsplit=1) if extra and extra[0]: self.content.insert(0, extra[0]) return line + b"\n" def unread_data(self, data: bytes) -> None: if self.content: self.content[0] = data + self.content[0] else: self.content.append(data) stream = SplitBoundaryStream() reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary="foobar"'}, stream ) part = await anext(reader) assert isinstance(part, BodyPartReader) result = await part.read_chunk(10) assert result == b"Hello," result = await part.read_chunk(10) assert result == b"" assert part.at_eof() part = await anext(reader) assert isinstance(part, BodyPartReader) result = await part.read_chunk(10) assert result == b"world!" 
result = await part.read_chunk(10) assert result == b"" assert part.at_eof() with pytest.raises(StopAsyncIteration): await anext(reader) async def test_release(self) -> None: with Stream( b"--:\r\n" b"Content-Type: multipart/related;boundary=--:--\r\n" b"\r\n" b"----:--\r\n" b"\r\n" b"test\r\n" b"----:--\r\n" b"\r\n" b"passed\r\n" b"----:----\r\n" b"\r\n" b"--:--" ) as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/mixed;boundary=":"'}, stream, ) await reader.release() assert reader.at_eof() async def test_release_release(self) -> None: with Stream(b"--:\r\n\r\necho\r\n--:--") as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) await reader.release() assert reader.at_eof() await reader.release() assert reader.at_eof() async def test_release_next(self) -> None: with Stream(b"--:\r\n\r\necho\r\n--:--") as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) await reader.release() assert reader.at_eof() res = await reader.next() assert res is None async def test_second_next_releases_previous_object(self) -> None: with Stream(b"--:\r\n\r\ntest\r\n--:\r\n\r\npassed\r\n--:--") as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) first = await reader.next() assert isinstance(first, aiohttp.BodyPartReader) second = await reader.next() assert second is not None assert first.at_eof() assert not second.at_eof() async def test_release_without_read_the_last_object(self) -> None: with Stream(b"--:\r\n\r\ntest\r\n--:\r\n\r\npassed\r\n--:--") as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) first = await reader.next() second = await reader.next() third = await reader.next() assert first is not None assert second is not None assert first.at_eof() assert second.at_eof() assert second.at_eof() assert third is None async def 
test_read_chunk_by_length_doesnt_break_reader(self) -> None: with Stream( b"--:\r\n" b"Content-Length: 4\r\n\r\n" b"test" b"\r\n--:\r\n" b"Content-Length: 6\r\n\r\n" b"passed" b"\r\n--:--" ) as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) body_parts = [] while True: read_part = b"" part = await reader.next() if part is None: break assert isinstance(part, BodyPartReader) while not part.at_eof(): read_part += await part.read_chunk(3) body_parts.append(read_part) assert body_parts == [b"test", b"passed"] async def test_read_chunk_from_stream_doesnt_break_reader(self) -> None: with Stream( b"--:\r\n" b"\r\n" b"chunk" b"\r\n--:\r\n" b"\r\n" b"two_chunks" b"\r\n--:--" ) as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) body_parts = [] while True: read_part = b"" part = await reader.next() if part is None: break assert isinstance(part, BodyPartReader) while not part.at_eof(): chunk = await part.read_chunk(5) assert chunk read_part += chunk body_parts.append(read_part) assert body_parts == [b"chunk", b"two_chunks"] async def test_reading_skips_prelude(self) -> None: with Stream( b"Multi-part data is not supported.\r\n" b"\r\n" b"--:\r\n" b"\r\n" b"test\r\n" b"--:\r\n" b"\r\n" b"passed\r\n" b"--:--" ) as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) first = await reader.next() assert isinstance(first, aiohttp.BodyPartReader) second = await reader.next() assert isinstance(second, BodyPartReader) assert first.at_eof() assert not second.at_eof() async def test_read_empty_body_part(self) -> None: with Stream(b"--:\r\n\r\n--:--") as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) body_parts = [] async for part in reader: assert isinstance(part, BodyPartReader) body_parts.append(await part.read()) assert body_parts == [b""] async def 
test_read_body_part_headers_only(self) -> None: with Stream(b"--:\r\nContent-Type: text/plain\r\n\r\n--:--") as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/related;boundary=":"'}, stream, ) body_parts = [] async for part in reader: assert isinstance(part, BodyPartReader) assert "Content-Type" in part.headers body_parts.append(await part.read()) assert body_parts == [b""] async def test_read_form_default_encoding(self) -> None: with Stream( b"--:\r\n" b'Content-Disposition: form-data; name="_charset_"\r\n\r\n' b"ascii" b"\r\n" b"--:\r\n" b'Content-Disposition: form-data; name="field1"\r\n\r\n' b"foo" b"\r\n" b"--:\r\n" b"Content-Type: text/plain;charset=UTF-8\r\n" b'Content-Disposition: form-data; name="field2"\r\n\r\n' b"foo" b"\r\n" b"--:\r\n" b'Content-Disposition: form-data; name="field3"\r\n\r\n' b"foo" b"\r\n" ) as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/form-data;boundary=":"'}, stream, ) field1 = await reader.next() assert isinstance(field1, BodyPartReader) assert field1.name == "field1" assert field1.get_charset("default") == "ascii" field2 = await reader.next() assert isinstance(field2, BodyPartReader) assert field2.name == "field2" assert field2.get_charset("default") == "UTF-8" field3 = await reader.next() assert isinstance(field3, BodyPartReader) assert field3.name == "field3" assert field3.get_charset("default") == "ascii" async def test_read_form_invalid_default_encoding(self) -> None: with Stream( b"--:\r\n" b'Content-Disposition: form-data; name="_charset_"\r\n\r\n' b"this-value-is-too-long-to-be-a-charset" b"\r\n" b"--:\r\n" b'Content-Disposition: form-data; name="field1"\r\n\r\n' b"foo" b"\r\n" ) as stream: reader = aiohttp.MultipartReader( {CONTENT_TYPE: 'multipart/form-data;boundary=":"'}, stream, ) with pytest.raises(RuntimeError, match="Invalid default charset"): await reader.next() async def test_writer(writer: aiohttp.MultipartWriter) -> None: assert writer.size == 7 assert writer.boundary 
== ":" async def test_writer_serialize_io_chunk( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: with io.BytesIO(b"foobarbaz") as file_handle: writer.append(file_handle) await writer.write(stream) assert ( buf == b"--:\r\nContent-Type: application/octet-stream" b"\r\nContent-Length: 9\r\n\r\nfoobarbaz\r\n--:--\r\n" ) async def test_writer_serialize_json( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: writer.append_json({"привет": "мир"}) await writer.write(stream) assert ( b'{"\\u043f\\u0440\\u0438\\u0432\\u0435\\u0442":' b' "\\u043c\\u0438\\u0440"}' in buf ) async def test_writer_serialize_form( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: data = [("foo", "bar"), ("foo", "baz"), ("boo", "zoo")] writer.append_form(data) await writer.write(stream) assert b"foo=bar&foo=baz&boo=zoo" in buf async def test_writer_serialize_form_dict( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: data = {"hello": "мир"} writer.append_form(data) await writer.write(stream) assert b"hello=%D0%BC%D0%B8%D1%80" in buf async def test_writer_write( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: writer.append("foo-bar-baz") writer.append_json({"test": "passed"}) writer.append_form({"test": "passed"}) writer.append_form([("one", "1"), ("two", "2")]) sub_multipart = aiohttp.MultipartWriter(boundary="::") sub_multipart.append("nested content") sub_multipart.headers["X-CUSTOM"] = "test" writer.append(sub_multipart) await writer.write(stream) assert ( b"--:\r\n" b"Content-Type: text/plain; charset=utf-8\r\n" b"Content-Length: 11\r\n\r\n" b"foo-bar-baz" b"\r\n" b"--:\r\n" b"Content-Type: application/json\r\n" b"Content-Length: 18\r\n\r\n" b'{"test": "passed"}' b"\r\n" b"--:\r\n" b"Content-Type: application/x-www-form-urlencoded\r\n" b"Content-Length: 11\r\n\r\n" b"test=passed" b"\r\n" 
b"--:\r\n" b"Content-Type: application/x-www-form-urlencoded\r\n" b"Content-Length: 11\r\n\r\n" b"one=1&two=2" b"\r\n" b"--:\r\n" b'Content-Type: multipart/mixed; boundary="::"\r\n' b"X-CUSTOM: test\r\nContent-Length: 93\r\n\r\n" b"--::\r\n" b"Content-Type: text/plain; charset=utf-8\r\n" b"Content-Length: 14\r\n\r\n" b"nested content\r\n" b"--::--\r\n" b"\r\n" b"--:--\r\n" ) == bytes(buf) async def test_writer_write_no_close_boundary( buf: bytearray, stream: AbstractStreamWriter ) -> None: writer = aiohttp.MultipartWriter(boundary=":") writer.append("foo-bar-baz") writer.append_json({"test": "passed"}) writer.append_form({"test": "passed"}) writer.append_form([("one", "1"), ("two", "2")]) await writer.write(stream, close_boundary=False) assert ( b"--:\r\n" b"Content-Type: text/plain; charset=utf-8\r\n" b"Content-Length: 11\r\n\r\n" b"foo-bar-baz" b"\r\n" b"--:\r\n" b"Content-Type: application/json\r\n" b"Content-Length: 18\r\n\r\n" b'{"test": "passed"}' b"\r\n" b"--:\r\n" b"Content-Type: application/x-www-form-urlencoded\r\n" b"Content-Length: 11\r\n\r\n" b"test=passed" b"\r\n" b"--:\r\n" b"Content-Type: application/x-www-form-urlencoded\r\n" b"Content-Length: 11\r\n\r\n" b"one=1&two=2" b"\r\n" ) == bytes(buf) async def test_writer_write_no_parts( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: await writer.write(stream) assert b"--:--\r\n" == bytes(buf) @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_writer_serialize_with_content_encoding_gzip( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter, ) -> None: writer.append("Time to Relax!", {CONTENT_ENCODING: "gzip"}) await writer.write(stream) headers, message = bytes(buf).split(b"\r\n\r\n", 1) assert ( b"--:\r\nContent-Type: text/plain; charset=utf-8\r\n" b"Content-Encoding: gzip" == headers ) decompressor = ZLibBackend.decompressobj(wbits=16 + ZLibBackend.MAX_WBITS) data = decompressor.decompress(message.split(b"\r\n")[0]) data 
+= decompressor.flush() assert b"Time to Relax!" == data async def test_writer_serialize_with_content_encoding_deflate( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: writer.append("Time to Relax!", {CONTENT_ENCODING: "deflate"}) await writer.write(stream) headers, message = bytes(buf).split(b"\r\n\r\n", 1) assert ( b"--:\r\nContent-Type: text/plain; charset=utf-8\r\n" b"Content-Encoding: deflate" == headers ) thing = b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--\r\n" assert thing == message async def test_writer_serialize_with_content_encoding_identity( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: thing = b"\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00" writer.append(thing, {CONTENT_ENCODING: "identity"}) await writer.write(stream) headers, message = bytes(buf).split(b"\r\n\r\n", 1) assert ( b"--:\r\nContent-Type: application/octet-stream\r\n" b"Content-Encoding: identity\r\n" b"Content-Length: 16" == headers ) assert thing == message.split(b"\r\n")[0] def test_writer_serialize_with_content_encoding_unknown( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: with pytest.raises(RuntimeError): writer.append("Time to Relax!", {CONTENT_ENCODING: "snappy"}) async def test_writer_with_content_transfer_encoding_base64( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: writer.append("Time to Relax!", {CONTENT_TRANSFER_ENCODING: "base64"}) await writer.write(stream) headers, message = bytes(buf).split(b"\r\n\r\n", 1) assert ( b"--:\r\nContent-Type: text/plain; charset=utf-8\r\n" b"Content-Transfer-Encoding: base64" == headers ) assert b"VGltZSB0byBSZWxheCE=" == message.split(b"\r\n")[0] async def test_writer_content_transfer_encoding_quote_printable( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: writer.append("Привет, мир!", {CONTENT_TRANSFER_ENCODING: 
"quoted-printable"}) await writer.write(stream) headers, message = bytes(buf).split(b"\r\n\r\n", 1) assert ( b"--:\r\nContent-Type: text/plain; charset=utf-8\r\n" b"Content-Transfer-Encoding: quoted-printable" == headers ) assert ( b"=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82," b" =D0=BC=D0=B8=D1=80!" == message.split(b"\r\n")[0] ) def test_writer_content_transfer_encoding_unknown( buf: bytearray, stream: AbstractStreamWriter, writer: aiohttp.MultipartWriter ) -> None: with pytest.raises(RuntimeError): writer.append("Time to Relax!", {CONTENT_TRANSFER_ENCODING: "unknown"}) class TestMultipartWriter: def test_default_subtype(self, writer: aiohttp.MultipartWriter) -> None: mimetype = parse_mimetype(writer.headers.get(CONTENT_TYPE)) assert "multipart" == mimetype.type assert "mixed" == mimetype.subtype def test_unquoted_boundary(self) -> None: writer = aiohttp.MultipartWriter(boundary="abc123") expected = {CONTENT_TYPE: "multipart/mixed; boundary=abc123"} assert expected == writer.headers def test_quoted_boundary(self) -> None: writer = aiohttp.MultipartWriter(boundary=R"\"") expected = {CONTENT_TYPE: R'multipart/mixed; boundary="\\\""'} assert expected == writer.headers def test_bad_boundary(self) -> None: with pytest.raises(ValueError): aiohttp.MultipartWriter(boundary="тест") with pytest.raises(ValueError): aiohttp.MultipartWriter(boundary="test\n") with pytest.raises(ValueError): aiohttp.MultipartWriter(boundary="X" * 71) def test_default_headers(self, writer: aiohttp.MultipartWriter) -> None: expected = {CONTENT_TYPE: 'multipart/mixed; boundary=":"'} assert expected == writer.headers def test_iter_parts(self, writer: aiohttp.MultipartWriter) -> None: writer.append("foo") writer.append("bar") writer.append("baz") assert 3 == len(list(writer)) def test_append(self, writer: aiohttp.MultipartWriter) -> None: assert 0 == len(writer) writer.append("hello, world!") assert 1 == len(writer) assert isinstance(writer._parts[0][0], payload.Payload) def 
test_append_with_headers(self, writer: aiohttp.MultipartWriter) -> None: writer.append("hello, world!", {"x-foo": "bar"}) assert 1 == len(writer) assert "x-foo" in writer._parts[0][0].headers assert writer._parts[0][0].headers["x-foo"] == "bar" def test_append_json(self, writer: aiohttp.MultipartWriter) -> None: writer.append_json({"foo": "bar"}) assert 1 == len(writer) part = writer._parts[0][0] assert part.headers[CONTENT_TYPE] == "application/json" def test_append_part(self, writer: aiohttp.MultipartWriter) -> None: part = payload.get_payload("test", headers={CONTENT_TYPE: "text/plain"}) writer.append(part, {CONTENT_TYPE: "test/passed"}) assert 1 == len(writer) part = writer._parts[0][0] assert part.headers[CONTENT_TYPE] == "test/passed" def test_append_json_overrides_content_type( self, writer: aiohttp.MultipartWriter ) -> None: writer.append_json({"foo": "bar"}, {CONTENT_TYPE: "test/passed"}) assert 1 == len(writer) part = writer._parts[0][0] assert part.headers[CONTENT_TYPE] == "test/passed" def test_append_form(self, writer: aiohttp.MultipartWriter) -> None: writer.append_form({"foo": "bar"}, {CONTENT_TYPE: "test/passed"}) assert 1 == len(writer) part = writer._parts[0][0] assert part.headers[CONTENT_TYPE] == "test/passed" def test_append_multipart(self, writer: aiohttp.MultipartWriter) -> None: subwriter = aiohttp.MultipartWriter(boundary=":") subwriter.append_json({"foo": "bar"}) writer.append(subwriter, {CONTENT_TYPE: "test/passed"}) assert 1 == len(writer) part = writer._parts[0][0] assert part.headers[CONTENT_TYPE] == "test/passed" def test_set_content_disposition_after_append(self) -> None: writer = aiohttp.MultipartWriter("form-data") part = writer.append("some-data") part.set_content_disposition("form-data", name="method") assert 'name="method"' in part.headers[CONTENT_DISPOSITION] def test_automatic_content_disposition(self) -> None: writer = aiohttp.MultipartWriter("form-data") writer.append_json(()) part = payload.StringPayload("foo") 
part.set_content_disposition("form-data", name="second") writer.append_payload(part) writer.append("foo") disps = tuple(p[0].headers[CONTENT_DISPOSITION] for p in writer._parts) assert 'name="section-0"' in disps[0] assert 'name="second"' in disps[1] assert 'name="section-2"' in disps[2] def test_with(self) -> None: with aiohttp.MultipartWriter(boundary=":") as writer: writer.append("foo") writer.append(b"bar") writer.append_json({"baz": True}) assert 3 == len(writer) def test_append_int_not_allowed(self) -> None: with pytest.raises(TypeError): with aiohttp.MultipartWriter(boundary=":") as writer: writer.append(1) def test_append_float_not_allowed(self) -> None: with pytest.raises(TypeError): with aiohttp.MultipartWriter(boundary=":") as writer: writer.append(1.1) def test_append_none_not_allowed(self) -> None: with pytest.raises(TypeError): with aiohttp.MultipartWriter(boundary=":") as writer: writer.append(None) async def test_write_preserves_content_disposition( self, buf: bytearray, stream: AbstractStreamWriter ) -> None: with aiohttp.MultipartWriter(boundary=":") as writer: part = writer.append(b"foo", headers={CONTENT_TYPE: "test/passed"}) part.set_content_disposition("form-data", filename="bug") await writer.write(stream) headers, message = bytes(buf).split(b"\r\n\r\n", 1) assert headers == ( b"--:\r\n" b"Content-Type: test/passed\r\n" b"Content-Length: 3\r\n" b"Content-Disposition:" b' form-data; filename="bug"' ) assert message == b"foo\r\n--:--\r\n" async def test_preserve_content_disposition_header( self, buf: bytearray, stream: AbstractStreamWriter ) -> None: # https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381 with pathlib.Path(__file__).open("rb") as fobj: with aiohttp.MultipartWriter("form-data", boundary=":") as writer: part = writer.append( fobj, headers={ CONTENT_DISPOSITION: 'attachments; filename="bug.py"', CONTENT_TYPE: "text/python", }, ) await writer.write(stream) assert part.headers[CONTENT_TYPE] == "text/python" assert 
part.headers[CONTENT_DISPOSITION] == ('attachments; filename="bug.py"') headers, _ = bytes(buf).split(b"\r\n\r\n", 1) assert headers == ( b"--:\r\n" b"Content-Type: text/python\r\n" b'Content-Disposition: attachments; filename="bug.py"' ) async def test_set_content_disposition_override( self, buf: bytearray, stream: AbstractStreamWriter ) -> None: # https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381 with pathlib.Path(__file__).open("rb") as fobj: with aiohttp.MultipartWriter("form-data", boundary=":") as writer: part = writer.append( fobj, headers={ CONTENT_DISPOSITION: 'attachments; filename="bug.py"', CONTENT_TYPE: "text/python", }, ) await writer.write(stream) assert part.headers[CONTENT_TYPE] == "text/python" assert part.headers[CONTENT_DISPOSITION] == ('attachments; filename="bug.py"') headers, _ = bytes(buf).split(b"\r\n\r\n", 1) assert headers == ( b"--:\r\n" b"Content-Type: text/python\r\n" b'Content-Disposition: attachments; filename="bug.py"' ) async def test_reset_content_disposition_header( self, buf: bytearray, stream: AbstractStreamWriter ) -> None: # https://github.com/aio-libs/aiohttp/pull/3475#issuecomment-451072381 with pathlib.Path(__file__).open("rb") as fobj: with aiohttp.MultipartWriter("form-data", boundary=":") as writer: part = writer.append( fobj, headers={CONTENT_TYPE: "text/plain"}, ) assert CONTENT_DISPOSITION in part.headers part.set_content_disposition("attachments", filename="bug.py") await writer.write(stream) headers, _ = bytes(buf).split(b"\r\n\r\n", 1) assert headers == ( b"--:\r\n" b"Content-Type: text/plain\r\n" b"Content-Disposition:" b' attachments; filename="bug.py"' ) async def test_async_for_reader() -> None: data: tuple[dict[str, str], int, bytes, bytes, bytes] = ( {"test": "passed"}, 42, b"plain text", b"aiohttp\n", b"no epilogue", ) with Stream( b"\r\n".join( [ b"--:", b"Content-Type: application/json", b"", json.dumps(data[0]).encode(), b"--:", b"Content-Type: application/json", b"", 
json.dumps(data[1]).encode(), b"--:", b'Content-Type: multipart/related; boundary="::"', b"", b"--::", b"Content-Type: text/plain", b"", data[2], b"--::", b'Content-Disposition: attachment; filename="aiohttp"', b"Content-Type: text/plain", b"Content-Length: 28", b"Content-Encoding: gzip", b"", b"\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03K\xcc\xcc\xcf())" b"\xe0\x02\x00\xd6\x90\xe2O\x08\x00\x00\x00", b"--::", b'Content-Type: multipart/related; boundary=":::"', b"", b"--:::", b"Content-Type: text/plain", b"", data[4], b"--:::--", b"--::--", b"", b"--:--", b"", ] ) ) as stream: reader = aiohttp.MultipartReader( headers={CONTENT_TYPE: 'multipart/mixed; boundary=":"'}, content=stream, ) idata = iter(data) async def check(reader: aiohttp.MultipartReader) -> None: async for part in reader: assert part is not None if isinstance(part, aiohttp.BodyPartReader): if part.headers[CONTENT_TYPE] == "application/json": assert next(idata) == (await part.json()) else: assert next(idata) == await part.read(decode=True) else: await check(part) await check(reader) async def test_async_for_bodypart() -> None: h = CIMultiDictProxy[str](CIMultiDict()) with Stream(b"foobarbaz\r\n--:--") as stream: part = aiohttp.BodyPartReader(boundary=b"--:", headers=h, content=stream) async for data in part: assert data == b"foobarbaz" async def test_multipart_writer_reusability( buf: bytearray, stream: mock.Mock, buf2: bytearray, stream2: mock.Mock, writer: aiohttp.MultipartWriter, ) -> None: """Test that MultipartWriter can be written multiple times.""" # Add some parts writer.append("text content") writer.append(b"binary content", {"Content-Type": "application/octet-stream"}) writer.append_json({"key": "value"}) # Test as_bytes multiple times bytes1 = await writer.as_bytes() bytes2 = await writer.as_bytes() bytes3 = await writer.as_bytes() # All as_bytes calls should return identical data assert bytes1 == bytes2 == bytes3 # Verify content is there assert b"text content" in bytes1 assert b"binary content" 
in bytes1 assert b'"key": "value"' in bytes1 # First write buf.clear() await writer.write(stream) result1 = bytes(buf) # Second write - should produce identical output buf2.clear() await writer.write(stream2) result2 = bytes(buf2) # Results should be identical assert result1 == result2 # Third write to ensure continued reusability buf.clear() await writer.write(stream) result3 = bytes(buf) assert result1 == result3 # as_bytes should still work after writes bytes4 = await writer.as_bytes() assert bytes1 == bytes4 async def test_multipart_writer_reusability_with_io_payloads( buf: bytearray, stream: mock.Mock, buf2: bytearray, stream2: mock.Mock, writer: aiohttp.MultipartWriter, ) -> None: """Test that MultipartWriter with IO payloads can be reused.""" # Create IO objects bytes_io = io.BytesIO(b"bytes io content") string_io = io.StringIO("string io content") # Add IO payloads writer.append(bytes_io, {"Content-Type": "application/octet-stream"}) writer.append(string_io, {"Content-Type": "text/plain"}) # Test as_bytes multiple times bytes1 = await writer.as_bytes() bytes2 = await writer.as_bytes() # All as_bytes calls should return identical data assert bytes1 == bytes2 assert b"bytes io content" in bytes1 assert b"string io content" in bytes1 # First write buf.clear() await writer.write(stream) result1 = bytes(buf) assert b"bytes io content" in result1 assert b"string io content" in result1 # Reset IO objects for reuse bytes_io.seek(0) string_io.seek(0) # Second write buf2.clear() await writer.write(stream2) result2 = bytes(buf2) # Should produce identical results assert result1 == result2 # Test as_bytes after writes (IO objects should auto-reset) bytes3 = await writer.as_bytes() assert bytes1 == bytes3 async def test_body_part_reader_payload_as_bytes() -> None: """Test that BodyPartReaderPayload.as_bytes raises TypeError.""" # Create a mock BodyPartReader headers = CIMultiDictProxy(CIMultiDict({CONTENT_TYPE: "text/plain"})) protocol = mock.Mock(_reading_paused=False) 
async def test_multipart_writer_close_with_exceptions() -> None:
    """Verify that MultipartWriter.close() keeps closing parts after one close() raises."""
    writer = aiohttp.MultipartWriter()

    failing, healthy, autoclosed, already_consumed = (mock.Mock() for _ in range(4))
    # (part, autoclose, consumed, close coroutine mock)
    specs = (
        (
            failing,
            False,
            False,
            mock.AsyncMock(side_effect=RuntimeError("Part 1 close failed")),
        ),
        (healthy, False, False, mock.AsyncMock()),
        (autoclosed, True, False, mock.AsyncMock()),
        (already_consumed, False, True, mock.AsyncMock()),
    )
    for part, autoclose, consumed, closer in specs:
        part.autoclose = autoclose
        part.consumed = consumed
        part.close = closer

    # Install the fake parts directly into the writer's internal list.
    writer._parts = [(part, "", "") for part, *_ in specs]

    # Must not propagate the RuntimeError raised by the first part.
    await writer.close()

    failing.close.assert_called_once()
    healthy.close.assert_called_once()  # reached even though the first part failed
    autoclosed.close.assert_not_called()  # autoclose=True
    already_consumed.close.assert_not_called()  # consumed=True

    assert writer._consumed is True

    # A second close() must not call close() on the parts again.
    await writer.close()
    assert failing.close.call_count == 1
    assert healthy.close.call_count == 1
def test_parse_empty(self) -> None:
    """A ``None`` header yields no disposition type and an empty parameter dict."""
    parsed = parse_content_disposition(None)
    kind, attrs = parsed
    assert kind is None
    assert attrs == {}
def test_inlwithasciifilenamepdf(self) -> None:
    """Parse the tc2231 ``inlwithasciifilenamepdf`` case: inline with a PDF filename.

    The case name (``inl...``) denotes an *inline* disposition.  The header
    previously used ``attachment``, which only repeated the shape already
    covered by ``test_attwithasciifilename`` and left the inline/PDF case
    unexercised.
    """
    disptype, params = parse_content_disposition('inline; filename="foo.pdf"')
    assert "inline" == disptype
    assert {"filename": "foo.pdf"} == params
def test_attwithutf8fnplain(self) -> None:
    """Parse the tc2231 ``attwithutf8fnplain`` case: raw UTF-8 bytes in a quoted filename.

    The upstream case sends the UTF-8 encoding of U+00E4 (bytes C3 A4)
    unescaped inside the quoted string; read as ISO-8859-1 that is the two
    characters U+00C3 U+00A4.  The header was previously identical to the one
    in ``test_attwithisofnplain``, so this test added no coverage.  Escapes
    are used so the two-character sequence survives editors and tooling that
    "repair" mojibake.
    """
    disptype, params = parse_content_disposition(
        'attachment; filename="foo-\u00c3\u00a4.html"'
    )
    assert "attachment" == disptype
    # NOTE(review): expects the quoted value to be returned verbatim, as it is
    # for the single non-ASCII character in test_attwithisofnplain.
    assert {"filename": "foo-\u00c3\u00a4.html"} == params
def test_attwithnamepct(self) -> None:
    """Parse the tc2231 ``attwithnamepct`` case: a ``name`` parameter with a percent sequence.

    The upstream case exercises the ``name`` parameter, not ``filename``.
    The header was previously byte-identical to the one in
    ``test_attwithfnrawpctenca``, so the ``name`` parameter was never parsed.
    """
    disptype, params = parse_content_disposition('attachment; name="foo-%41.html"')
    assert "attachment" == disptype
    # The percent sequence is expected to stay undecoded, mirroring the
    # quoted ``filename`` value in test_attwithfnrawpctenca.
    assert {"name": "foo-%41.html"} == params
def test_attmissingdisposition(self) -> None:
    """A header with only a parameter and no disposition type warns and parses to nothing."""
    header = "filename=foo.html"
    with pytest.warns(aiohttp.BadContentDispositionHeader):
        parsed = parse_content_disposition(header)
    kind, attrs = parsed
    assert kind is None
    assert attrs == {}
def test_attconfusedparam(self) -> None:
    """A parameter named ``xfilename`` is returned under its own key."""
    header = "attachment; xfilename=foo.html"
    kind, attrs = parse_content_disposition(header)
    assert kind == "attachment"
    assert attrs == {"xfilename": "foo.html"}
def test_attwithfn2231ws2(self) -> None:
    """Whitespace between ``*=`` and the value still yields the decoded extended filename."""
    header = "attachment; filename*= UTF-8''foo-%c3%a4.html"
    kind, attrs = parse_content_disposition(header)
    assert kind == "attachment"
    assert attrs == {"filename*": "foo-ä.html"}
def test_attwithfn2231dpct(self) -> None:
    """A doubly percent-encoded sequence (``%2541``) is decoded exactly once."""
    header = "attachment; filename*=UTF-8''A-%2541.html"
    kind, attrs = parse_content_disposition(header)
    assert kind == "attachment"
    assert attrs == {"filename*": "A-%41.html"}
def test_attfnconts1(self) -> None:
    """Parse the tc2231 ``attfnconts1`` case: continuations that start at ``*1``.

    The upstream case has no ``*0`` section at all.  The header previously
    used ``filename*0``/``filename*2``, which is the same shape as
    ``test_attfncontnc`` (a gap after ``*0``) and disagreed with
    ``TestContentDispositionFilename.test_attfnconts1``, whose input is
    ``filename*1``/``filename*2``.
    """
    disptype, params = parse_content_disposition(
        'attachment; filename*1="foo."; filename*2="html"'
    )
    assert "attachment" == disptype
    # The parser is expected to return each section under its own key;
    # rejecting the missing ``*0`` is left to content_disposition_filename().
    assert {"filename*1": "foo.", "filename*2": "html"} == params
def test_no_filename(self) -> None:
    """Parameter dicts without any filename key produce ``None``."""
    for params in ({}, {"foo": "bar"}):
        assert content_disposition_filename(params) is None
class BufferWriter(AbstractStreamWriter):
    """Stream writer double that accumulates every written chunk in ``buffer``."""

    def __init__(self) -> None:
        self.buffer = bytearray()

    async def write(
        self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
    ) -> None:
        """Append a ``bytes`` copy of *chunk* to ``buffer``."""
        self.buffer += bytes(chunk)

    async def write_eof(self, chunk: bytes = b"") -> None:
        """Ignore end-of-stream; nothing is recorded."""
        return None

    async def drain(self) -> None:
        """Return immediately; there is nothing to flush."""
        return None

    def enable_compression(
        self, encoding: str = "deflate", strategy: int | None = None
    ) -> None:
        """Accept the call without enabling any compression."""
        return None

    def enable_chunking(self) -> None:
        """Accept the call without enabling chunked encoding."""
        return None

    async def write_headers(self, status_line: str, headers: CIMultiDict[str]) -> None:
        """Discard the status line and headers; only body bytes are kept."""
        return None
@pytest.fixture
def registry() -> Iterator[payload.PayloadRegistry]:
    """Install a fresh ``PayloadRegistry`` for the test and restore the previous one afterwards."""
    saved = payload.PAYLOAD_REGISTRY
    fresh = payload.PayloadRegistry()
    payload.PAYLOAD_REGISTRY = fresh
    yield fresh
    # Teardown: put back whatever registry was active before the test.
    payload.PAYLOAD_REGISTRY = saved
class MockStreamWriter(AbstractStreamWriter):
    """Stream writer double that records each written chunk as a separate list entry."""

    def __init__(self) -> None:
        self.written: list[bytes] = []

    async def write(
        self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
    ) -> None:
        """Record a ``bytes`` copy of *chunk* at the end of ``written``."""
        snapshot = bytes(chunk)
        self.written.append(snapshot)

    async def write_eof(self, chunk: bytes | None = None) -> None:
        """Ignore end-of-stream; nothing is recorded."""
        return None

    async def drain(self) -> None:
        """Return immediately; there is nothing to flush."""
        return None

    def enable_compression(
        self, encoding: str = "deflate", strategy: int | None = None
    ) -> None:
        """Accept the call without enabling any compression."""
        return None

    def enable_chunking(self) -> None:
        """Accept the call without enabling chunked encoding."""
        return None

    async def write_headers(self, status_line: str, headers: CIMultiDict[str]) -> None:
        """Discard the status line and headers."""
        return None

    def get_written_bytes(self) -> bytes:
        """Concatenate every recorded chunk, in write order, into one ``bytes`` object."""
        joined = bytearray()
        for piece in self.written:
            joined += piece
        return bytes(joined)
test_iobase_payload_write_with_length_truncated() -> None: """Test IOBasePayload writing with truncated content length.""" data = b"0123456789" p = payload.IOBasePayload(io.BytesIO(data)) writer = MockStreamWriter() await p.write_with_length(writer, 5) assert writer.get_written_bytes() == b"01234" assert len(writer.get_written_bytes()) == 5 async def test_bytesio_payload_write_with_length_no_limit() -> None: """Test BytesIOPayload writing with no content length limit.""" data = b"0123456789" p = payload.BytesIOPayload(io.BytesIO(data)) writer = MockStreamWriter() await p.write_with_length(writer, None) assert writer.get_written_bytes() == data assert len(writer.get_written_bytes()) == 10 async def test_bytesio_payload_write_with_length_exact() -> None: """Test BytesIOPayload writing with exact content length.""" data = b"0123456789" p = payload.BytesIOPayload(io.BytesIO(data)) writer = MockStreamWriter() await p.write_with_length(writer, 10) assert writer.get_written_bytes() == data assert len(writer.get_written_bytes()) == 10 async def test_bytesio_payload_write_with_length_truncated() -> None: """Test BytesIOPayload writing with truncated content length.""" data = b"0123456789" payload_bytesio = payload.BytesIOPayload(io.BytesIO(data)) writer = MockStreamWriter() await payload_bytesio.write_with_length(writer, 5) assert writer.get_written_bytes() == b"01234" assert len(writer.get_written_bytes()) == 5 async def test_bytesio_payload_write_with_length_remaining_zero() -> None: """Test BytesIOPayload with content_length smaller than first read chunk.""" data = b"0123456789" * 10 # 100 bytes bio = io.BytesIO(data) payload_bytesio = payload.BytesIOPayload(bio) writer = MockStreamWriter() # Mock the read method to return smaller chunks original_read = bio.read read_calls = 0 def mock_read(size: int | None = None) -> bytes: nonlocal read_calls read_calls += 1 if read_calls == 1: # First call: return 3 bytes (less than content_length=5) return original_read(3) else: # 
Subsequent calls return remaining data normally return original_read(size) with unittest.mock.patch.object(bio, "read", mock_read): await payload_bytesio.write_with_length(writer, 5) assert len(writer.get_written_bytes()) == 5 assert writer.get_written_bytes() == b"01234" async def test_bytesio_payload_large_data_multiple_chunks() -> None: """Test BytesIOPayload with large data requiring multiple read chunks.""" chunk_size = 2**16 # 64KB (READ_SIZE) data = b"x" * (chunk_size + 1000) # Slightly larger than READ_SIZE payload_bytesio = payload.BytesIOPayload(io.BytesIO(data)) writer = MockStreamWriter() await payload_bytesio.write_with_length(writer, None) assert writer.get_written_bytes() == data assert len(writer.get_written_bytes()) == chunk_size + 1000 async def test_bytesio_payload_remaining_bytes_exhausted() -> None: """Test BytesIOPayload when remaining_bytes becomes <= 0.""" data = b"0123456789abcdef" * 1000 # 16000 bytes payload_bytesio = payload.BytesIOPayload(io.BytesIO(data)) writer = MockStreamWriter() await payload_bytesio.write_with_length(writer, 8000) # Exactly half the data written = writer.get_written_bytes() assert len(written) == 8000 assert written == data[:8000] async def test_iobase_payload_exact_chunk_size_limit() -> None: """Test IOBasePayload with content length matching exactly one read chunk.""" chunk_size = 2**16 # 65536 bytes (READ_SIZE) data = b"x" * chunk_size + b"extra" # Slightly larger than one read chunk p = payload.IOBasePayload(io.BytesIO(data)) writer = MockStreamWriter() await p.write_with_length(writer, chunk_size) written = writer.get_written_bytes() assert len(written) == chunk_size assert written == data[:chunk_size] async def test_iobase_payload_reads_in_chunks() -> None: """Test IOBasePayload reads data in chunks of READ_SIZE, not all at once.""" # Create a large file that's multiple times larger than READ_SIZE large_data = b"x" * (READ_SIZE * 3 + 1000) # ~192KB + 1000 bytes # Mock the file-like object to track read calls 
mock_file = unittest.mock.Mock(spec=io.BytesIO) mock_file.tell.return_value = 0 mock_file.fileno.side_effect = AttributeError # Make size return None # Track the sizes of read() calls read_sizes = [] def mock_read(size: int) -> bytes: read_sizes.append(size) # Return data based on how many times read was called call_count = len(read_sizes) if call_count == 1: return large_data[:size] elif call_count == 2: return large_data[READ_SIZE : READ_SIZE + size] elif call_count == 3: return large_data[READ_SIZE * 2 : READ_SIZE * 2 + size] else: return large_data[READ_SIZE * 3 :] mock_file.read.side_effect = mock_read payload_obj = payload.IOBasePayload(mock_file) writer = MockStreamWriter() # Write with a large content_length await payload_obj.write_with_length(writer, len(large_data)) # Verify that reads were limited to READ_SIZE assert len(read_sizes) > 1 # Should have multiple reads for read_size in read_sizes: assert ( read_size <= READ_SIZE ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}" async def test_iobase_payload_large_content_length() -> None: """Test IOBasePayload with very large content_length doesn't read all at once.""" data = b"x" * (READ_SIZE + 1000) # Create a custom file-like object that tracks read sizes class TrackingBytesIO(io.BytesIO): def __init__(self, data: bytes) -> None: super().__init__(data) self.read_sizes: list[int] = [] def read(self, size: int | None = -1) -> bytes: self.read_sizes.append(size if size is not None else -1) return super().read(size) tracking_file = TrackingBytesIO(data) payload_obj = payload.IOBasePayload(tracking_file) writer = MockStreamWriter() # Write with a very large content_length (simulating the bug scenario) large_content_length = 10 * 1024 * 1024 # 10MB await payload_obj.write_with_length(writer, large_content_length) # Verify no single read exceeded READ_SIZE for read_size in tracking_file.read_sizes: assert ( read_size <= READ_SIZE ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}" # Verify the 
correct amount of data was written assert writer.get_written_bytes() == data async def test_textio_payload_reads_in_chunks() -> None: """Test TextIOPayload reads data in chunks of READ_SIZE, not all at once.""" # Create a large text file that's multiple times larger than READ_SIZE large_text = "x" * (READ_SIZE * 3 + 1000) # ~192KB + 1000 chars # Mock the file-like object to track read calls mock_file = unittest.mock.Mock(spec=io.StringIO) mock_file.tell.return_value = 0 mock_file.fileno.side_effect = AttributeError # Make size return None mock_file.encoding = "utf-8" # Track the sizes of read() calls read_sizes = [] def mock_read(size: int) -> str: read_sizes.append(size) # Return data based on how many times read was called call_count = len(read_sizes) if call_count == 1: return large_text[:size] elif call_count == 2: return large_text[READ_SIZE : READ_SIZE + size] elif call_count == 3: return large_text[READ_SIZE * 2 : READ_SIZE * 2 + size] else: return large_text[READ_SIZE * 3 :] mock_file.read.side_effect = mock_read payload_obj = payload.TextIOPayload(mock_file) writer = MockStreamWriter() # Write with a large content_length await payload_obj.write_with_length(writer, len(large_text.encode("utf-8"))) # Verify that reads were limited to READ_SIZE assert len(read_sizes) > 1 # Should have multiple reads for read_size in read_sizes: assert ( read_size <= READ_SIZE ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}" async def test_textio_payload_large_content_length() -> None: """Test TextIOPayload with very large content_length doesn't read all at once.""" text_data = "x" * (READ_SIZE + 1000) # Create a custom file-like object that tracks read sizes class TrackingStringIO(io.StringIO): def __init__(self, data: str) -> None: super().__init__(data) self.read_sizes: list[int] = [] def read(self, size: int | None = -1) -> str: self.read_sizes.append(size if size is not None else -1) return super().read(size) tracking_file = TrackingStringIO(text_data) 
payload_obj = payload.TextIOPayload(tracking_file) writer = MockStreamWriter() # Write with a very large content_length (simulating the bug scenario) large_content_length = 10 * 1024 * 1024 # 10MB await payload_obj.write_with_length(writer, large_content_length) # Verify no single read exceeded READ_SIZE for read_size in tracking_file.read_sizes: assert ( read_size <= READ_SIZE ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}" # Verify the correct amount of data was written assert writer.get_written_bytes() == text_data.encode("utf-8") async def test_async_iterable_payload_write_with_length_no_limit() -> None: """Test AsyncIterablePayload writing with no content length limit.""" async def gen() -> AsyncIterator[bytes]: yield b"0123" yield b"4567" yield b"89" p = payload.AsyncIterablePayload(gen()) writer = MockStreamWriter() await p.write_with_length(writer, None) assert writer.get_written_bytes() == b"0123456789" assert len(writer.get_written_bytes()) == 10 async def test_async_iterable_payload_write_with_length_exact() -> None: """Test AsyncIterablePayload writing with exact content length.""" async def gen() -> AsyncIterator[bytes]: yield b"0123" yield b"4567" yield b"89" p = payload.AsyncIterablePayload(gen()) writer = MockStreamWriter() await p.write_with_length(writer, 10) assert writer.get_written_bytes() == b"0123456789" assert len(writer.get_written_bytes()) == 10 async def test_async_iterable_payload_write_with_length_truncated_mid_chunk() -> None: """Test AsyncIterablePayload writing with content length truncating mid-chunk.""" async def gen() -> AsyncIterator[bytes]: yield b"0123" yield b"4567" yield b"89" # pragma: no cover p = payload.AsyncIterablePayload(gen()) writer = MockStreamWriter() await p.write_with_length(writer, 6) assert writer.get_written_bytes() == b"012345" assert len(writer.get_written_bytes()) == 6 async def test_async_iterable_payload_write_with_length_truncated_at_chunk() -> None: """Test AsyncIterablePayload writing with 
content length truncating at chunk boundary.""" async def gen() -> AsyncIterator[bytes]: yield b"0123" yield b"4567" # pragma: no cover yield b"89" # pragma: no cover p = payload.AsyncIterablePayload(gen()) writer = MockStreamWriter() await p.write_with_length(writer, 4) assert writer.get_written_bytes() == b"0123" assert len(writer.get_written_bytes()) == 4 async def test_bytes_payload_backwards_compatibility() -> None: """Test BytesPayload.write() backwards compatibility delegates to write_with_length().""" p = payload.BytesPayload(b"1234567890") writer = MockStreamWriter() await p.write(writer) assert writer.get_written_bytes() == b"1234567890" async def test_textio_payload_with_encoding() -> None: """Test TextIOPayload reading with encoding and size constraints.""" data = io.StringIO("hello world") p = payload.TextIOPayload(data, encoding="utf-8") writer = MockStreamWriter() await p.write_with_length(writer, 8) # Should write exactly 8 bytes: "hello wo" assert writer.get_written_bytes() == b"hello wo" async def test_textio_payload_as_bytes() -> None: """Test TextIOPayload.as_bytes method with different encodings.""" # Test with UTF-8 encoding data = io.StringIO("Hello 世界") p = payload.TextIOPayload(data, encoding="utf-8") # Test as_bytes() method result = await p.as_bytes() assert result == "Hello 世界".encode() # Test that position is restored for multiple reads result2 = await p.as_bytes() assert result2 == "Hello 世界".encode() # Test with different encoding parameter (should use instance encoding) result3 = await p.as_bytes(encoding="latin-1") assert result3 == "Hello 世界".encode() # Should still use utf-8 # Test with different encoding in payload data2 = io.StringIO("Hello World") p2 = payload.TextIOPayload(data2, encoding="latin-1") result4 = await p2.as_bytes() assert result4 == b"Hello World" # latin-1 encoding # Test with no explicit encoding (defaults to utf-8) data3 = io.StringIO("Test データ") p3 = payload.TextIOPayload(data3) result5 = await p3.as_bytes() 
assert result5 == "Test データ".encode() # Test with encoding errors parameter data4 = io.StringIO("Test") p4 = payload.TextIOPayload(data4, encoding="ascii") result6 = await p4.as_bytes(errors="strict") assert result6 == b"Test" async def test_bytesio_payload_backwards_compatibility() -> None: """Test BytesIOPayload.write() backwards compatibility delegates to write_with_length().""" data = io.BytesIO(b"test data") p = payload.BytesIOPayload(data) writer = MockStreamWriter() await p.write(writer) assert writer.get_written_bytes() == b"test data" async def test_async_iterable_payload_backwards_compatibility() -> None: """Test AsyncIterablePayload.write() backwards compatibility delegates to write_with_length().""" async def gen() -> AsyncIterator[bytes]: yield b"chunk1" yield b"chunk2" # pragma: no cover p = payload.AsyncIterablePayload(gen()) writer = MockStreamWriter() await p.write(writer) assert writer.get_written_bytes() == b"chunk1chunk2" async def test_async_iterable_payload_with_none_iterator() -> None: """Test AsyncIterablePayload with None iterator returns early without writing.""" async def gen() -> AsyncIterator[bytes]: yield b"test" # pragma: no cover p = payload.AsyncIterablePayload(gen()) # Manually set _iter to None to test the guard clause p._iter = None writer = MockStreamWriter() # Should return early without writing anything await p.write_with_length(writer, 10) assert writer.get_written_bytes() == b"" async def test_async_iterable_payload_caching() -> None: """Test AsyncIterablePayload caching behavior.""" async def gen() -> AsyncIterator[bytes]: yield b"Hello" yield b" " yield b"World" p = payload.AsyncIterablePayload(gen()) # First call to as_bytes should consume iterator and cache result1 = await p.as_bytes() assert result1 == b"Hello World" assert p._iter is None # Iterator exhausted assert p._cached_chunks == [b"Hello", b" ", b"World"] # Chunks cached assert p._consumed is False # Not marked as consumed to allow reuse # Second call should use 
cache result2 = await p.as_bytes() assert result2 == b"Hello World" assert p._cached_chunks == [b"Hello", b" ", b"World"] # Still cached # decode should work with cached chunks decoded = p.decode() assert decoded == "Hello World" # write_with_length should use cached chunks writer = MockStreamWriter() await p.write_with_length(writer, None) assert writer.get_written_bytes() == b"Hello World" # write_with_length with limit should respect it writer2 = MockStreamWriter() await p.write_with_length(writer2, 5) assert writer2.get_written_bytes() == b"Hello" async def test_async_iterable_payload_decode_without_cache() -> None: """Test AsyncIterablePayload decode raises error without cache.""" async def gen() -> AsyncIterator[bytes]: yield b"test" p = payload.AsyncIterablePayload(gen()) # decode should raise without cache with pytest.raises(TypeError) as excinfo: p.decode() assert "Unable to decode - content not cached" in str(excinfo.value) # After as_bytes, decode should work await p.as_bytes() assert p.decode() == "test" async def test_async_iterable_payload_write_then_cache() -> None: """Test AsyncIterablePayload behavior when written before caching.""" async def gen() -> AsyncIterator[bytes]: yield b"Hello" yield b"World" p = payload.AsyncIterablePayload(gen()) # First write without caching (streaming) writer1 = MockStreamWriter() await p.write_with_length(writer1, None) assert writer1.get_written_bytes() == b"HelloWorld" assert p._iter is None # Iterator exhausted assert p._cached_chunks is None # No cache created assert p._consumed is True # Marked as consumed # Subsequent operations should handle exhausted iterator result = await p.as_bytes() assert result == b"" # Empty since iterator exhausted without cache # Write should also be empty writer2 = MockStreamWriter() await p.write_with_length(writer2, None) assert writer2.get_written_bytes() == b"" async def test_bytes_payload_reusability() -> None: """Test that BytesPayload can be written and read multiple 
times.""" data = b"test payload data" p = payload.BytesPayload(data) # First write_with_length writer1 = MockStreamWriter() await p.write_with_length(writer1, None) assert writer1.get_written_bytes() == data # Second write_with_length (simulating redirect) writer2 = MockStreamWriter() await p.write_with_length(writer2, None) assert writer2.get_written_bytes() == data # Write with partial length writer3 = MockStreamWriter() await p.write_with_length(writer3, 5) assert writer3.get_written_bytes() == b"test " # Test as_bytes multiple times bytes1 = await p.as_bytes() bytes2 = await p.as_bytes() bytes3 = await p.as_bytes() assert bytes1 == bytes2 == bytes3 == data async def test_string_payload_reusability() -> None: """Test that StringPayload can be written and read multiple times.""" text = "test string data" expected_bytes = text.encode("utf-8") p = payload.StringPayload(text) # First write_with_length writer1 = MockStreamWriter() await p.write_with_length(writer1, None) assert writer1.get_written_bytes() == expected_bytes # Second write_with_length (simulating redirect) writer2 = MockStreamWriter() await p.write_with_length(writer2, None) assert writer2.get_written_bytes() == expected_bytes # Write with partial length writer3 = MockStreamWriter() await p.write_with_length(writer3, 5) assert writer3.get_written_bytes() == b"test " # Test as_bytes multiple times bytes1 = await p.as_bytes() bytes2 = await p.as_bytes() bytes3 = await p.as_bytes() assert bytes1 == bytes2 == bytes3 == expected_bytes async def test_bytes_io_payload_reusability() -> None: """Test that BytesIOPayload can be written and read multiple times.""" data = b"test bytesio payload" bytes_io = io.BytesIO(data) p = payload.BytesIOPayload(bytes_io) # First write_with_length writer1 = MockStreamWriter() await p.write_with_length(writer1, None) assert writer1.get_written_bytes() == data # Second write_with_length (simulating redirect) writer2 = MockStreamWriter() await p.write_with_length(writer2, None) 
assert writer2.get_written_bytes() == data # Write with partial length writer3 = MockStreamWriter() await p.write_with_length(writer3, 5) assert writer3.get_written_bytes() == b"test " # Test as_bytes multiple times bytes1 = await p.as_bytes() bytes2 = await p.as_bytes() bytes3 = await p.as_bytes() assert bytes1 == bytes2 == bytes3 == data async def test_string_io_payload_reusability() -> None: """Test that StringIOPayload can be written and read multiple times.""" text = "test stringio payload" expected_bytes = text.encode("utf-8") string_io = io.StringIO(text) p = payload.StringIOPayload(string_io) # Note: StringIOPayload reads all content in __init__ and becomes a StringPayload # So it should be fully reusable # First write_with_length writer1 = MockStreamWriter() await p.write_with_length(writer1, None) assert writer1.get_written_bytes() == expected_bytes # Second write_with_length (simulating redirect) writer2 = MockStreamWriter() await p.write_with_length(writer2, None) assert writer2.get_written_bytes() == expected_bytes # Write with partial length writer3 = MockStreamWriter() await p.write_with_length(writer3, 5) assert writer3.get_written_bytes() == b"test " # Test as_bytes multiple times bytes1 = await p.as_bytes() bytes2 = await p.as_bytes() bytes3 = await p.as_bytes() assert bytes1 == bytes2 == bytes3 == expected_bytes async def test_buffered_reader_payload_reusability() -> None: """Test that BufferedReaderPayload can be written and read multiple times.""" data = b"test buffered reader payload" buffer = io.BufferedReader(io.BytesIO(data)) p = payload.BufferedReaderPayload(buffer) # First write_with_length writer1 = MockStreamWriter() await p.write_with_length(writer1, None) assert writer1.get_written_bytes() == data # Second write_with_length (simulating redirect) writer2 = MockStreamWriter() await p.write_with_length(writer2, None) assert writer2.get_written_bytes() == data # Write with partial length writer3 = MockStreamWriter() await 
p.write_with_length(writer3, 5) assert writer3.get_written_bytes() == b"test " # Test as_bytes multiple times bytes1 = await p.as_bytes() bytes2 = await p.as_bytes() bytes3 = await p.as_bytes() assert bytes1 == bytes2 == bytes3 == data async def test_async_iterable_payload_reusability_with_cache() -> None: """Test that AsyncIterablePayload can be reused when cached via as_bytes.""" async def gen() -> AsyncIterator[bytes]: yield b"async " yield b"iterable " yield b"payload" expected_data = b"async iterable payload" p = payload.AsyncIterablePayload(gen()) # First call to as_bytes should cache the data bytes1 = await p.as_bytes() assert bytes1 == expected_data assert p._cached_chunks is not None assert p._iter is None # Iterator exhausted # Subsequent as_bytes calls should use cache bytes2 = await p.as_bytes() bytes3 = await p.as_bytes() assert bytes1 == bytes2 == bytes3 == expected_data # Now writes should also use the cached data writer1 = MockStreamWriter() await p.write_with_length(writer1, None) assert writer1.get_written_bytes() == expected_data # Second write should also work writer2 = MockStreamWriter() await p.write_with_length(writer2, None) assert writer2.get_written_bytes() == expected_data # Write with partial length writer3 = MockStreamWriter() await p.write_with_length(writer3, 5) assert writer3.get_written_bytes() == b"async" async def test_async_iterable_payload_no_reuse_without_cache() -> None: """Test that AsyncIterablePayload cannot be reused without caching.""" async def gen() -> AsyncIterator[bytes]: yield b"test " yield b"data" p = payload.AsyncIterablePayload(gen()) # First write exhausts the iterator writer1 = MockStreamWriter() await p.write_with_length(writer1, None) assert writer1.get_written_bytes() == b"test data" assert p._iter is None # Iterator exhausted assert p._consumed is True # Second write should produce empty result writer2 = MockStreamWriter() await p.write_with_length(writer2, None) assert writer2.get_written_bytes() == b"" 
async def test_bytes_io_payload_close_does_not_close_io() -> None: """Test that BytesIOPayload close() does not close the underlying BytesIO.""" bytes_io = io.BytesIO(b"data") bytes_io_payload = payload.BytesIOPayload(bytes_io) # Close the payload await bytes_io_payload.close() # BytesIO should NOT be closed assert not bytes_io.closed # Can still write after close writer = MockStreamWriter() await bytes_io_payload.write_with_length(writer, None) assert writer.get_written_bytes() == b"data" async def test_custom_payload_backwards_compat_as_bytes() -> None: """Test backwards compatibility for custom Payload that only implements decode().""" class LegacyPayload(payload.Payload): """A custom payload that only implements decode() like old code might do.""" def __init__(self, data: str) -> None: super().__init__(data, headers=CIMultiDict()) self._data = data def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: """Custom decode implementation.""" return self._data async def write(self, writer: AbstractStreamWriter) -> None: """Write implementation which is a no-op for this test.""" # Create instance with test data p = LegacyPayload("Hello, World!") # Test that as_bytes() works even though it's not explicitly implemented # The base class should call decode() and encode the result result = await p.as_bytes() assert result == b"Hello, World!" # Test with different text p2 = LegacyPayload("Test with special chars: café") result_utf8 = await p2.as_bytes(encoding="utf-8") assert result_utf8 == "Test with special chars: café".encode() # Test that decode() still works as expected assert p.decode() == "Hello, World!" 
assert p2.decode() == "Test with special chars: café" async def test_custom_payload_with_encoding_backwards_compat() -> None: """Test custom Payload with encoding set uses instance encoding for as_bytes().""" class EncodedPayload(payload.Payload): """A custom payload with specific encoding.""" def __init__(self, data: str, encoding: str) -> None: super().__init__(data, headers=CIMultiDict(), encoding=encoding) self._data = data def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: """Custom decode implementation.""" return self._data async def write(self, writer: AbstractStreamWriter) -> None: """Write implementation is a no-op.""" # Create instance with specific encoding p = EncodedPayload("Test data", encoding="latin-1") # as_bytes() should use the instance encoding (latin-1) not the default utf-8 result = await p.as_bytes() assert result == b"Test data" # ASCII chars are same in latin-1 # Test with non-ASCII that differs between encodings p2 = EncodedPayload("café", encoding="latin-1") result_latin1 = await p2.as_bytes() assert result_latin1 == "café".encode("latin-1") assert result_latin1 != "café".encode() # Should be different bytes async def test_iobase_payload_close_idempotent() -> None: """Test that IOBasePayload.close() is idempotent and covers the _consumed check.""" file_like = io.BytesIO(b"test data") p = payload.IOBasePayload(file_like) # First close should set _consumed to True await p.close() assert p._consumed is True # Second close should be a no-op due to _consumed check (line 621) await p.close() assert p._consumed is True def test_iobase_payload_decode() -> None: """Test IOBasePayload.decode() returns correct string.""" # Test with UTF-8 encoded text text = "Hello, 世界! 
🌍" file_like = io.BytesIO(text.encode("utf-8")) p = payload.IOBasePayload(file_like) # decode() should return the original string assert p.decode() == text # Test with different encoding latin1_text = "café" file_like2 = io.BytesIO(latin1_text.encode("latin-1")) p2 = payload.IOBasePayload(file_like2) assert p2.decode("latin-1") == latin1_text # Test that file position is restored file_like3 = io.BytesIO(b"test data") file_like3.read(4) # Move position forward p3 = payload.IOBasePayload(file_like3) # decode() should read from the stored start position (4) assert p3.decode() == " data" def test_bytes_payload_size() -> None: """Test BytesPayload.size property returns correct byte length.""" # Test with bytes bp = payload.BytesPayload(b"Hello World") assert bp.size == 11 # Test with empty bytes bp_empty = payload.BytesPayload(b"") assert bp_empty.size == 0 # Test with bytearray ba = bytearray(b"Hello World") bp_array = payload.BytesPayload(ba) assert bp_array.size == 11 def test_string_payload_size() -> None: """Test StringPayload.size property with different encodings.""" # Test ASCII string with default UTF-8 encoding sp = payload.StringPayload("Hello World") assert sp.size == 11 # Test Unicode string with default UTF-8 encoding unicode_str = "Hello 世界" sp_unicode = payload.StringPayload(unicode_str) assert sp_unicode.size == len(unicode_str.encode("utf-8")) # Test with UTF-16 encoding sp_utf16 = payload.StringPayload("Hello World", encoding="utf-16") assert sp_utf16.size == len("Hello World".encode("utf-16")) # Test with latin-1 encoding sp_latin1 = payload.StringPayload("café", encoding="latin-1") assert sp_latin1.size == len("café".encode("latin-1")) def test_string_io_payload_size() -> None: """Test StringIOPayload.size property.""" # Test normal string sio = StringIO("Hello World") siop = payload.StringIOPayload(sio) assert siop.size == 11 # Test Unicode string sio_unicode = StringIO("Hello 世界") siop_unicode = payload.StringIOPayload(sio_unicode) assert 
siop_unicode.size == len("Hello 世界".encode()) # Test with custom encoding sio_custom = StringIO("Hello") siop_custom = payload.StringIOPayload(sio_custom, encoding="utf-16") assert siop_custom.size == len("Hello".encode("utf-16")) # Test with emoji to ensure correct byte count sio_emoji = StringIO("Hello 👋🌍") siop_emoji = payload.StringIOPayload(sio_emoji) assert siop_emoji.size == len("Hello 👋🌍".encode()) # Verify it's not the string length assert siop_emoji.size != len("Hello 👋🌍") def test_all_string_payloads_size_is_bytes() -> None: """Test that all string-like payload classes report size in bytes, not string length.""" # Test string with multibyte characters test_str = "Hello 👋 世界 🌍" # Contains emoji and Chinese characters # StringPayload sp = payload.StringPayload(test_str) assert sp.size == len(test_str.encode("utf-8")) assert sp.size != len(test_str) # Ensure it's not string length # StringIOPayload sio = StringIO(test_str) siop = payload.StringIOPayload(sio) assert siop.size == len(test_str.encode("utf-8")) assert siop.size != len(test_str) # Test with different encoding sp_utf16 = payload.StringPayload(test_str, encoding="utf-16") assert sp_utf16.size == len(test_str.encode("utf-16")) assert sp_utf16.size != sp.size # Different encoding = different size # JsonPayload (which extends BytesPayload) json_data = {"message": test_str} jp = payload.JsonPayload(json_data) # JSON escapes Unicode, so we need to check the actual encoded size json_str = json.dumps(json_data) assert jp.size == len(json_str.encode("utf-8")) # Test JsonPayload with ensure_ascii=False to get actual UTF-8 encoding jp_utf8 = payload.JsonPayload( json_data, dumps=lambda x: json.dumps(x, ensure_ascii=False) ) json_str_utf8 = json.dumps(json_data, ensure_ascii=False) assert jp_utf8.size == len(json_str_utf8.encode("utf-8")) assert jp_utf8.size != len( json_str_utf8 ) # Now it's different due to multibyte chars def test_bytes_io_payload_size() -> None: """Test BytesIOPayload.size property.""" # 
Test normal bytes bio = io.BytesIO(b"Hello World") biop = payload.BytesIOPayload(bio) assert biop.size == 11 # Test empty BytesIO bio_empty = io.BytesIO(b"") biop_empty = payload.BytesIOPayload(bio_empty) assert biop_empty.size == 0 # Test with position not at start bio_pos = io.BytesIO(b"Hello World") bio_pos.seek(5) biop_pos = payload.BytesIOPayload(bio_pos) assert biop_pos.size == 6 # Size should be from position to end def test_json_payload_size() -> None: """Test JsonPayload.size property.""" # Test simple dict data = {"hello": "world"} jp = payload.JsonPayload(data) expected_json = json.dumps(data) # Use actual json.dumps output assert jp.size == len(expected_json.encode("utf-8")) # Test with Unicode data_unicode = {"message": "Hello 世界"} jp_unicode = payload.JsonPayload(data_unicode) expected_unicode = json.dumps(data_unicode) assert jp_unicode.size == len(expected_unicode.encode("utf-8")) # Test with custom encoding data_custom = {"test": "data"} jp_custom = payload.JsonPayload(data_custom, encoding="utf-16") expected_custom = json.dumps(data_custom) assert jp_custom.size == len(expected_custom.encode("utf-16")) def test_json_bytes_payload() -> None: """Test JsonBytesPayload with a bytes-returning encoder.""" data = {"hello": "world"} # Test with standard library encoder jp = payload.JsonBytesPayload(data, dumps=lambda x: json.dumps(x).encode("utf-8")) expected = json.dumps(data).encode("utf-8") assert jp.size == len(expected) # Test with custom bytes-returning encoder (compact separators) jp_custom = payload.JsonBytesPayload( data, dumps=lambda x: json.dumps(x, separators=(",", ":")).encode("utf-8") ) expected_custom = json.dumps(data, separators=(",", ":")).encode("utf-8") assert jp_custom.size == len(expected_custom) def test_json_bytes_payload_content_type() -> None: """Test JsonBytesPayload content_type.""" data = {"test": "data"} # Default content type jp = payload.JsonBytesPayload(data, dumps=lambda x: json.dumps(x).encode("utf-8")) assert 
async def test_text_io_payload_size_utf16(tmp_path: Path) -> None:
    """Check TextIOPayload.size and the written byte count for a UTF-16 file."""
    utf16_path = tmp_path / "test_utf16.txt"
    text = "Hello World"
    loop = asyncio.get_running_loop()

    def stat_size() -> int:
        return utf16_path.stat().st_size

    def open_utf16() -> TextIO:
        return open(utf16_path, encoding="utf-16")

    # All blocking file-system work is pushed to the default executor.
    await loop.run_in_executor(None, utf16_path.write_text, text, "utf-16")
    size_on_disk = await loop.run_in_executor(None, stat_size)
    fobj = await loop.run_in_executor(None, open_utf16)

    try:
        text_payload = payload.TextIOPayload(fobj, encoding="utf-16")
        # The payload reports the on-disk (UTF-16) size...
        assert text_payload.size == size_on_disk

        # ...and that is also how many bytes actually get written out.
        sink = BufferWriter()
        await text_payload.write(sink)
        assert len(sink.buffer) == size_on_disk
    finally:
        await loop.run_in_executor(None, fobj.close)
async def test_iobase_payload_size_unseekable() -> None:
    """Check IOBasePayload with a file object whose ``tell()`` raises OSError."""

    class UnseekableFile:
        """Minimal file stand-in: readable, but position queries fail."""

        def __init__(self, content: bytes) -> None:
            self._data = content
            self._offset = 0

        def read(self, size: int) -> bytes:
            start = self._offset
            chunk = self._data[start : start + size]
            self._offset = start + len(chunk)
            return chunk

        def tell(self) -> int:
            raise OSError("Unseekable file")

    raw = b"Unseekable content"
    unseekable_payload = payload.IOBasePayload(
        UnseekableFile(raw)  # type: ignore[arg-type]
    )

    # Without a working tell() the size is reported as unknown.
    assert unseekable_payload.size is None

    # Nothing has been read yet, so the payload is not consumed.
    assert unseekable_payload.consumed is False

    # Writing still delivers the full content.
    sink = BufferWriter()
    await unseekable_payload.write(sink)
    assert sink.buffer == raw

    # For unseekable files that can't tell() or seek(),
    # they are marked as consumed after the first write
    assert unseekable_payload.consumed is True
@mock.patch("aiohttp.connector.ClientRequestBase")
@mock.patch(
    "aiohttp.connector.aiohappyeyeballs.start_connection",
    autospec=True,
    spec_set=True,
)
async def test_proxy_headers(  # type: ignore[misc]
    start_connection: mock.Mock,
    ClientRequestMock: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that ``proxy_headers`` are merged into the request sent to the proxy."""
    loop = asyncio.get_running_loop()
    req = make_client_request(
        "GET",
        URL("http://www.python.org"),
        proxy=URL("http://proxy.example.com"),
        proxy_headers=CIMultiDict({"Foo": "Bar"}),
        loop=loop,
        ssl=True,
        headers=CIMultiDict({}),
    )
    assert str(req.proxy) == "http://proxy.example.com"

    connector = aiohttp.TCPConnector()
    resolved = dict(
        hostname="hostname",
        host="127.0.0.1",
        port=80,
        family=socket.AF_INET,
        proto=0,
        flags=0,
    )
    # Protocol whose transport reports no extra info.
    proto = mock.Mock()
    proto.transport.get_extra_info.return_value = False

    with (
        mock.patch.object(
            connector, "_resolve_host", autospec=True, return_value=[resolved]
        ),
        mock.patch.object(
            loop,
            "create_connection",
            autospec=True,
            return_value=(proto.transport, proto),
        ),
    ):
        conn = await connector.connect(req, [], aiohttp.ClientTimeout())

        # The original request is untouched and the mocked connection is used.
        assert req.url == URL("http://www.python.org")
        assert conn._protocol is proto
        assert conn.transport is proto.transport

        # The proxy request carries Host plus the user-supplied proxy header.
        ClientRequestMock.assert_called_with(
            "GET",
            URL("http://proxy.example.com"),
            auth=None,
            headers={"Host": "www.python.org", "Foo": "Bar"},
            loop=loop,
            ssl=True,
        )

        conn.close()
        await connector.close()
@mock.patch("aiohttp.connector.ClientRequestBase")
@mock.patch(
    "aiohttp.connector.aiohappyeyeballs.start_connection",
    autospec=True,
    spec_set=True,
)
async def test_proxy_server_hostname_default(  # type: ignore[misc]
    start_connection: mock.Mock,
    ClientRequestMock: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that ``start_tls`` gets the target host as ``server_hostname`` by default."""
    loop = asyncio.get_running_loop()
    proxy_url = URL("http://proxy.example.com")

    # Real proxy request/response objects; only their I/O entry points are patched.
    proxy_req = ClientRequestBase(
        "GET",
        proxy_url,
        auth=None,
        loop=loop,
        ssl=True,
        headers=CIMultiDict({}),
    )
    ClientRequestMock.return_value = proxy_req
    proxy_resp = ClientResponse(
        "get",
        proxy_url,
        writer=None,
        continue100=None,
        timer=TimerNoop(),
        traces=[],
        loop=loop,
        session=mock.Mock(),
        request_headers=CIMultiDict[str](),
        original_url=proxy_url,
    )

    connector = aiohttp.TCPConnector()
    resolved = dict(
        hostname="hostname",
        host="127.0.0.1",
        port=80,
        family=socket.AF_INET,
        proto=0,
        flags=0,
    )

    with (
        mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp),
        mock.patch.object(proxy_resp, "start", autospec=True) as start_m,
        mock.patch.object(
            connector, "_resolve_host", autospec=True, return_value=[resolved]
        ),
        mock.patch.object(
            loop,
            "create_connection",
            autospec=True,
            return_value=(mock.Mock(), mock.Mock()),
        ),
        mock.patch.object(
            loop, "start_tls", autospec=True, return_value=mock.Mock()
        ) as tls_m,
    ):
        # The proxy answers the CONNECT with 200.
        start_m.return_value.status = 200

        req = make_client_request(
            "GET",
            URL("https://www.python.org"),
            proxy=URL("http://proxy.example.com"),
            loop=loop,
        )
        await connector._create_connection(req, [], aiohttp.ClientTimeout())

        assert tls_m.call_args.kwargs["server_hostname"] == "www.python.org"

        proxy_resp.close()
        await req._close()
        await connector.close()
@mock.patch("aiohttp.connector.ClientRequestBase")
@mock.patch(
    "aiohttp.connector.aiohappyeyeballs.start_connection",
    autospec=True,
    spec_set=True,
)
@pytest.mark.usefixtures("enable_cleanup_closed")
@pytest.mark.parametrize("cleanup", (True, False))
async def test_https_connect_fingerprint_mismatch(  # type: ignore[misc]
    start_connection: mock.Mock,
    ClientRequestMock: mock.Mock,
    cleanup: bool,
    make_client_request: _RequestMaker,
) -> None:
    """Check that a fingerprint mismatch after the proxy CONNECT propagates.

    The proxy handshake is mocked to succeed (status 200) and ``start_tls``
    returns a fake transport; the fingerprint check is forced to raise
    ``ServerFingerprintMismatch``, which ``_create_connection`` must re-raise.
    Runs with ``enable_cleanup_closed`` both on and off.
    """
    event_loop = asyncio.get_running_loop()
    proxy_req = ClientRequestBase(
        "GET",
        URL("http://proxy.example.com"),
        auth=None,
        loop=event_loop,
        ssl=True,
        headers=CIMultiDict({}),
    )
    ClientRequestMock.return_value = proxy_req

    class TransportMock(asyncio.Transport):
        """Transport stand-in that only supports a no-op ``close()``."""

        def close(self) -> None:
            pass

    url = URL("http://proxy.example.com")
    proxy_resp = ClientResponse(
        "get",
        url,
        writer=mock.Mock(),
        continue100=None,
        timer=TimerNoop(),
        traces=[],
        loop=event_loop,
        session=mock.Mock(),
        request_headers=CIMultiDict[str](),
        original_url=url,
    )

    # ``Mock`` has no ``auto_spec`` keyword (it would merely become a stray
    # attribute), so build a real autospecced Fingerprint instance instead.
    fingerprint_mock = mock.create_autospec(Fingerprint, spec_set=True, instance=True)
    fingerprint_mock.check.side_effect = aiohttp.ServerFingerprintMismatch(
        b"exp", b"got", "example.com", 8080
    )

    with (
        mock.patch.object(
            proxy_req,
            "_send",
            autospec=True,
            spec_set=True,
            return_value=proxy_resp,
        ),
        mock.patch.object(
            proxy_resp,
            "start",
            autospec=True,
            spec_set=True,
            return_value=mock.Mock(status=200),
        ),
    ):
        connector = aiohttp.TCPConnector(enable_cleanup_closed=cleanup)
        host = [
            {
                "hostname": "hostname",
                "host": "127.0.0.1",
                "port": 80,
                "family": socket.AF_INET,
                "proto": 0,
                "flags": 0,
            }
        ]
        with (
            mock.patch.object(
                connector,
                "_resolve_host",
                autospec=True,
                spec_set=True,
                return_value=host,
            ),
            mock.patch.object(
                connector,
                "_get_fingerprint",
                autospec=True,
                spec_set=True,
                return_value=fingerprint_mock,
            ),
            mock.patch.object(  # Called on connection to http://proxy.example.com
                event_loop,
                "create_connection",
                autospec=True,
                spec_set=True,
                return_value=(mock.Mock(), mock.Mock()),
            ),
            mock.patch.object(  # Called on connection to https://www.python.org
                event_loop,
                "start_tls",
                autospec=True,
                spec_set=True,
                return_value=TransportMock(),
            ),
        ):
            req = make_client_request(
                "GET",
                URL("https://www.python.org"),
                proxy=URL("http://proxy.example.com"),
                loop=event_loop,
            )
            with pytest.raises(aiohttp.ServerFingerprintMismatch):
                await connector._create_connection(req, [], aiohttp.ClientTimeout())
@mock.patch("aiohttp.connector.ClientRequestBase")
@mock.patch(
    "aiohttp.connector.aiohappyeyeballs.start_connection",
    autospec=True,
    spec_set=True,
)
async def test_https_connect_certificate_error(  # type: ignore[misc]
    start_connection: mock.Mock,
    ClientRequestMock: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that ``ssl.CertificateError`` from ``start_tls`` is re-raised as
    ``ClientConnectorCertificateError``.
    """
    loop = asyncio.get_running_loop()
    proxy_url = URL("http://proxy.example.com")

    proxy_req = ClientRequestBase(
        "GET",
        proxy_url,
        auth=None,
        loop=loop,
        ssl=True,
        headers=CIMultiDict({}),
    )
    ClientRequestMock.return_value = proxy_req
    proxy_resp = ClientResponse(
        "get",
        proxy_url,
        writer=None,
        continue100=None,
        timer=TimerNoop(),
        traces=[],
        loop=loop,
        session=mock.Mock(),
        request_headers=CIMultiDict[str](),
        original_url=proxy_url,
    )

    connector = aiohttp.TCPConnector()
    resolved = dict(
        hostname="hostname",
        host="127.0.0.1",
        port=80,
        family=socket.AF_INET,
        proto=0,
        flags=0,
    )

    with (
        mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp),
        mock.patch.object(proxy_resp, "start", autospec=True) as start_m,
        mock.patch.object(
            connector, "_resolve_host", autospec=True, return_value=[resolved]
        ),
        mock.patch.object(  # Called on connection to http://proxy.example.com
            loop,
            "create_connection",
            autospec=True,
            return_value=(mock.Mock(), mock.Mock()),
        ),
        mock.patch.object(  # Called on connection to https://www.python.org
            loop,
            "start_tls",
            autospec=True,
            side_effect=ssl.CertificateError,
        ),
    ):
        # The proxy accepts the CONNECT; only the TLS upgrade fails.
        start_m.return_value.status = 200

        req = make_client_request(
            "GET",
            URL("https://www.python.org"),
            proxy=URL("http://proxy.example.com"),
            loop=loop,
        )
        with pytest.raises(aiohttp.ClientConnectorCertificateError):
            await connector._create_connection(req, [], aiohttp.ClientTimeout())

        await connector.close()
@mock.patch("aiohttp.connector.ClientRequestBase")
@mock.patch(
    "aiohttp.connector.aiohappyeyeballs.start_connection",
    autospec=True,
    spec_set=True,
)
async def test_https_connect_resp_start_error(  # type: ignore[misc]
    start_connection: mock.Mock,
    ClientRequestMock: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check that an OSError raised by the proxy response's ``start`` propagates."""
    loop = asyncio.get_running_loop()
    proxy_url = URL("http://proxy.example.com")

    proxy_req = ClientRequestBase(
        "GET",
        proxy_url,
        auth=None,
        loop=loop,
        ssl=True,
        headers=CIMultiDict({}),
    )
    ClientRequestMock.return_value = proxy_req
    proxy_resp = ClientResponse(
        "get",
        proxy_url,
        writer=None,
        continue100=None,
        timer=TimerNoop(),
        traces=[],
        loop=loop,
        session=mock.Mock(),
        request_headers=CIMultiDict[str](),
        original_url=proxy_url,
    )

    connector = aiohttp.TCPConnector()
    resolved = dict(
        hostname="hostname",
        host="127.0.0.1",
        port=80,
        family=socket.AF_INET,
        proto=0,
        flags=0,
    )
    transport, protocol = mock.Mock(), mock.Mock()
    transport.get_extra_info.return_value = None

    with (
        mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp),
        mock.patch.object(
            proxy_resp, "start", autospec=True, side_effect=OSError("error message")
        ),
        mock.patch.object(
            connector, "_resolve_host", autospec=True, return_value=[resolved]
        ),
        mock.patch.object(  # Called on connection to http://proxy.example.com
            loop,
            "create_connection",
            autospec=True,
            return_value=(transport, protocol),
        ),
    ):
        req = make_client_request(
            "GET",
            URL("https://www.python.org"),
            proxy=URL("http://proxy.example.com"),
            loop=loop,
        )
        with pytest.raises(OSError, match="error message"):
            await connector._create_connection(req, [], aiohttp.ClientTimeout())

        await connector.close()
@mock.patch("aiohttp.connector.ClientRequestBase")
@mock.patch(
    "aiohttp.connector.aiohappyeyeballs.start_connection",
    autospec=True,
    spec_set=True,
)
async def test_https_connect_pass_ssl_context(  # type: ignore[misc]
    start_connection: mock.Mock,
    ClientRequestMock: mock.Mock,
    make_client_request: _RequestMaker,
) -> None:
    """Check the arguments ``start_tls`` receives when tunnelling HTTPS via a proxy."""
    loop = asyncio.get_running_loop()
    proxy_url = URL("http://proxy.example.com")

    proxy_req = ClientRequestBase(
        "GET",
        proxy_url,
        auth=None,
        loop=loop,
        ssl=True,
        headers=CIMultiDict({}),
    )
    ClientRequestMock.return_value = proxy_req
    proxy_resp = ClientResponse(
        "get",
        proxy_url,
        writer=None,
        continue100=None,
        timer=TimerNoop(),
        traces=[],
        loop=loop,
        session=mock.Mock(),
        request_headers=CIMultiDict[str](),
        original_url=proxy_url,
    )

    connector = aiohttp.TCPConnector()
    resolved = dict(
        hostname="hostname",
        host="127.0.0.1",
        port=80,
        family=socket.AF_INET,
        proto=0,
        flags=0,
    )

    with (
        mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp),
        mock.patch.object(proxy_resp, "start", autospec=True) as start_m,
        mock.patch.object(
            connector, "_resolve_host", autospec=True, return_value=[resolved]
        ),
        mock.patch.object(
            loop,
            "create_connection",
            autospec=True,
            return_value=(mock.Mock(), mock.Mock()),
        ),
        mock.patch.object(
            loop, "start_tls", autospec=True, return_value=mock.Mock()
        ) as tls_m,
    ):
        # The proxy answers the CONNECT with 200.
        start_m.return_value.status = 200

        req = make_client_request(
            "GET",
            URL("https://www.python.org"),
            proxy=URL("http://proxy.example.com"),
            loop=loop,
        )
        await connector._create_connection(req, [], aiohttp.ClientTimeout())

        # ssl_shutdown_timeout=0 is not passed to start_tls
        tls_m.assert_called_with(
            mock.ANY,
            mock.ANY,
            _SSL_CONTEXT_VERIFIED,
            server_hostname="www.python.org",
            ssl_handshake_timeout=mock.ANY,
        )

        # The user request is untouched; the proxy request became a CONNECT
        # aimed at the target origin.
        assert req.url.path == "/"
        assert proxy_req.method == "CONNECT"
        assert proxy_req.url == URL("https://www.python.org")

        proxy_resp.close()
        await req._close()
        await connector.close()
ClientResponse( "get", url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=event_loop, session=mock.Mock(), request_headers=CIMultiDict[str](), original_url=url, ) with mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp): with mock.patch.object(proxy_resp, "start", autospec=True) as m: m.return_value.status = 200 connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", "port": 80, "family": socket.AF_INET, "proto": 0, "flags": 0, } with mock.patch.object( connector, "_resolve_host", autospec=True, return_value=[r] ) as host_m: tr, proto = mock.Mock(), mock.Mock() with mock.patch.object( event_loop, "create_connection", autospec=True, return_value=(tr, proto), ): with mock.patch.object( event_loop, "start_tls", autospec=True, return_value=mock.Mock(), ): assert "AUTHORIZATION" in proxy_req.headers assert "PROXY-AUTHORIZATION" not in proxy_req.headers req = make_client_request( "GET", URL("https://www.python.org"), proxy=URL("http://proxy.example.com"), loop=event_loop, ) assert "AUTHORIZATION" not in req.headers assert "PROXY-AUTHORIZATION" not in req.headers await connector._create_connection( req, [], aiohttp.ClientTimeout() ) assert req.url.path == "/" assert "AUTHORIZATION" not in req.headers assert "PROXY-AUTHORIZATION" not in req.headers assert "AUTHORIZATION" not in proxy_req.headers assert "PROXY-AUTHORIZATION" in proxy_req.headers host_m.assert_called_with( "proxy.example.com", 80, traces=mock.ANY ) proxy_resp.close() await req._close() await connector.close() ================================================ FILE: tests/test_proxy_functional.py ================================================ import asyncio import os import pathlib import platform import ssl import sys from collections.abc import Awaitable, Callable, Iterator from contextlib import suppress from re import match as match_regex from typing import TYPE_CHECKING, TypedDict from unittest import mock from uuid import uuid4 
@pytest.fixture
def secure_proxy_url(tls_certificate_pem_path: str) -> Iterator[URL]:
    """Yield the ``https://`` URL of a proxy.py instance started for one test.

    The proxy is started by entering ``proxy.Proxy`` as a context manager and
    the URL is built from the hostname and port reported in its ``flags``.
    Leaving the context manager after the test is what tears the proxy down.

    :param tls_certificate_pem_path: path to a PEM file that is passed as both
        the certificate file and the key file.
    """
    # --threadless does not work on windows, see
    # https://github.com/abhinavsingh/proxy.py/issues/492
    threading_flag = "--threaded" if os.name == "nt" else "--threadless"

    # The same PEM file contains both the key and the cert.
    pem_path = tls_certificate_pem_path

    cli_args = [threading_flag]
    # The tests only send one query anyway.
    cli_args += ["--num-workers", "1"]
    # Network interface to listen to.
    cli_args += ["--hostname", "127.0.0.1"]
    # Ephemeral port, so that the kernel allocates a free one.
    cli_args += ["--port", "0"]
    cli_args += ["--cert-file", pem_path]
    cli_args += ["--key-file", pem_path]

    with proxy.Proxy(input_args=cli_args) as running_proxy:
        flags = running_proxy.flags
        yield URL.build(
            scheme="https",
            host=str(flags.hostname),
            port=flags.port,
        )
@pytest.mark.parametrize("web_server_endpoint_type", ("https",))
@pytest.mark.usefixtures("loop")
@pytest.mark.skipif(
    ASYNCIO_SUPPORTS_TLS_IN_TLS, reason="asyncio on this python supports TLS in TLS"
)
@pytest.mark.filterwarnings(r"ignore:.*ssl.OP_NO_SSL*")
# Filter out the warning from
# https://github.com/abhinavsingh/proxy.py/blob/30574fd0414005dfa8792a6e797023e862bdcf43/proxy/common/utils.py#L226
# otherwise this test will fail because the proxy will die with an error.
async def test_https_proxy_unsupported_tls_in_tls(
    client_ssl_ctx: ssl.SSLContext,
    secure_proxy_url: URL,
    web_server_endpoint_type: str,
) -> None:
    """Ensure connecting to TLS endpoints w/ HTTPS proxy needs patching.

    This also checks that a helpful warning on how to patch the env
    is displayed.

    The request is expected to emit a ``RuntimeWarning`` and to raise
    ``ClientConnectionError`` whose ``__cause__`` is a ``TypeError``.
    """
    url = URL.build(scheme=web_server_endpoint_type, host="python.org")
    assert url.host is not None

    # Both values are interpolated into regular expressions below, so the
    # dots are escaped to be matched literally.
    escaped_host_port = ":".join((url.host.replace(".", r"\."), str(url.port)))
    escaped_proxy_url = str(secure_proxy_url).replace(".", r"\.")

    expected_warning_text = (
        r"^"
        r"An HTTPS request is being sent through an HTTPS proxy\. "
        "This support for TLS in TLS is known to be disabled "
        r"in the stdlib asyncio\. This is why you'll probably see "
        r"an error in the log below\.\n\n"
        r"It is possible to enable it via monkeypatching\. "
        r"For more details, see:\n"
        r"\* https://bugs\.python\.org/issue37179\n"
        r"\* https://github\.com/python/cpython/pull/28073\n\n"
        r"You can temporarily patch this as follows:\n"
        r"\* https://docs\.aiohttp\.org/en/stable/client_advanced\.html#proxy-support\n"
        r"\* https://github\.com/aio-libs/aiohttp/discussions/6044\n$"
    )
    # NOTE(review): the text between "transport" and "is not supported" is
    # matched loosely; presumably the runtime message embeds the transport's
    # repr there -- confirm against the asyncio version under test.
    type_err = r"transport .*is not supported by start_tls\(\)"
    expected_exception_reason = (
        r"^"
        "Cannot initialize a TLS-in-TLS connection to host "
        f"{escaped_host_port!s} through an underlying connection "
        f"to an HTTPS proxy {escaped_proxy_url!s} ssl:{client_ssl_ctx!s} "
        # The brackets are escaped: unescaped, they would form a regex
        # character class matching a single character instead of the
        # literal "[...]" wrapper around the underlying error text.
        rf"\[{type_err!s}\]"
        r"$"
    )

    conn = aiohttp.TCPConnector()
    sess = aiohttp.ClientSession(connector=conn)
    try:
        with (
            pytest.warns(
                RuntimeWarning,
                match=expected_warning_text,
            ),
            pytest.raises(
                ClientConnectionError,
                match=expected_exception_reason,
            ) as conn_err,
        ):
            async with sess.get(url, proxy=secure_proxy_url, ssl=client_ssl_ctx):
                pass

        assert isinstance(conn_err.value.__cause__, TypeError)
        assert match_regex(f"^{type_err!s}$", str(conn_err.value.__cause__))
    finally:
        # Close the client objects even when an assertion above fails.
        await sess.close()
        await conn.close()

        await asyncio.sleep(0.1)
@pytest.fixture
def proxy_test_server(
    aiohttp_raw_server: AiohttpRawServer,
    loop: asyncio.AbstractEventLoop,
    monkeypatch: pytest.MonkeyPatch,
) -> Callable[[], Awaitable[mock.Mock]]:
    """Return a factory that starts a fake proxy and hands back its mock.

    The fake proxy handles all proxy requests and imitates the remote server
    response. Every request it receives is stored on the mock as ``request``
    (the latest one) and appended to ``requests_list``. Assigning a dict to
    the mock's ``return_value`` overrides the default response arguments
    (status 200, no headers, no body).
    """
    _patch_ssl_transport(monkeypatch)

    fake_proxy = mock.Mock()
    base_response = _ResponseArgs(status=200, headers=None, body=None)

    async def proxy_handler(request: web.Request) -> web.Response:
        # Record the request so that the test can inspect it afterwards.
        fake_proxy.request = request
        fake_proxy.requests_list.append(request)

        response_args = base_response.copy()
        override = fake_proxy.return_value
        if isinstance(override, dict):
            response_args.update(override)  # type: ignore[typeddict-item]

        # Falsy headers (None or an empty mapping) become a fresh dict.
        response_args["headers"] = response_args["headers"] or {}
        # A reply to CONNECT never carries a body here.
        if request.method == "CONNECT":
            response_args["body"] = None

        reply = web.Response(**response_args)
        await reply.prepare(request)
        await reply.write_eof()
        return reply

    async def proxy_server() -> mock.Mock:
        # Reset the recorded state before the server is started.
        fake_proxy.request = None
        fake_proxy.auth = None
        fake_proxy.requests_list = []

        raw_server = await aiohttp_raw_server(proxy_handler)  # type: ignore[arg-type]
        fake_proxy.server = raw_server
        fake_proxy.url = raw_server.make_url("/")
        return fake_proxy

    return proxy_server
async def test_proxy_http_auth(
    proxy_test_server: Callable[[], Awaitable[mock.Mock]],
) -> None:
    """Check which auth headers reach the proxy for each auth/proxy_auth mix."""
    url = "http://aiohttp.io/path"
    auth = aiohttp.BasicAuth("user", "pass")
    proxy = await proxy_test_server()

    def seen_auth_headers() -> tuple[bool, bool]:
        # (Authorization present, Proxy-Authorization present) for the
        # most recent request recorded by the fake proxy.
        sent = proxy.request.headers
        return "Authorization" in sent, "Proxy-Authorization" in sent

    # No credentials at all.
    await get_request(url=url, proxy=proxy.url)
    assert seen_auth_headers() == (False, False)

    # Credentials for the target only.
    await get_request(url=url, auth=auth, proxy=proxy.url)
    assert seen_auth_headers() == (True, False)

    # Credentials for the proxy only.
    await get_request(url=url, proxy_auth=auth, proxy=proxy.url)
    assert seen_auth_headers() == (False, True)

    # Credentials for both.
    await get_request(url=url, auth=auth, proxy_auth=auth, proxy=proxy.url)
    assert seen_auth_headers() == (True, True)
async def test_proxy_http_acquired_cleanup(
    proxy_test_server: Callable[[], Awaitable[mock.Mock]],
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Check the connector holds no acquired connections after a proxied GET.

    ``conn._acquired`` must be empty both before the request and after the
    response context manager has exited, and the response must be closed.
    """
    url = "http://aiohttp.io/path"

    conn = aiohttp.TCPConnector()
    sess = aiohttp.ClientSession(connector=conn)
    try:
        proxy = await proxy_test_server()

        assert 0 == len(conn._acquired)

        async with sess.get(url, proxy=proxy.url) as resp:
            pass
        assert resp.closed

        assert 0 == len(conn._acquired)
    finally:
        # Close the client objects even when an assertion above fails,
        # matching the try/finally cleanup of the https variants of this test.
        await sess.close()
        await conn.close()
@pytest.mark.xfail
async def test_proxy_https_send_body(
    proxy_test_server: Callable[[], Awaitable[mock.Mock]],
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Read a 1 MiB body from an https URL requested through the fake proxy.

    The test is marked ``xfail``.
    """
    expected_body = b"1" * (2**20)

    # Leaving the ``async with`` block closes the session, including when
    # something inside it fails.
    async with aiohttp.ClientSession() as sess:
        proxy = await proxy_test_server()
        proxy.return_value = {"status": 200, "body": b"1" * (2**20)}
        url = "https://www.google.com.ua/search?q=aiohttp proxy"

        async with sess.get(url, proxy=proxy.url) as resp:
            received = await resp.read()
        assert received == expected_body
async def test_proxy_https_bad_response(
    proxy_test_server: Callable[[], Awaitable[mock.Mock]],
) -> None:
    """Check that a 502 reply from the proxy raises ``ClientHttpProxyError``.

    Exactly one request must have reached the fake proxy, and it must be
    the ``CONNECT`` request.
    """
    proxy = await proxy_test_server()
    proxy.return_value = dict(status=502, headers={"Proxy-Agent": "TestProxy"})
    target = "https://secure.aiohttp.io/path"

    with pytest.raises(aiohttp.ClientHttpProxyError):
        await get_request(url=target, proxy=proxy.url)

    recorded = proxy.requests_list
    assert len(recorded) == 1
    assert proxy.request.method == "CONNECT"
    # The following check fails on MacOS
    # assert proxy.request.path == 'secure.aiohttp.io:443'
"Proxy-Authorization" in connect.headers assert "Authorization" not in proxy.request.headers assert "Proxy-Authorization" not in proxy.request.headers proxy = await proxy_test_server() await get_request(url=url, auth=auth, proxy_auth=auth, proxy=proxy.url) connect = proxy.requests_list[0] assert "Authorization" not in connect.headers assert "Proxy-Authorization" in connect.headers assert "Authorization" in proxy.request.headers assert "Proxy-Authorization" not in proxy.request.headers @pytest.mark.xfail async def test_proxy_https_acquired_cleanup( proxy_test_server: Callable[[], Awaitable[mock.Mock]], loop: asyncio.AbstractEventLoop, ) -> None: url = "https://secure.aiohttp.io/path" conn = aiohttp.TCPConnector() sess = aiohttp.ClientSession(connector=conn) try: proxy = await proxy_test_server() assert 0 == len(conn._acquired) async def request() -> None: async with sess.get(url, proxy=proxy.url): assert 1 == len(conn._acquired) await request() assert 0 == len(conn._acquired) finally: await sess.close() await conn.close() @pytest.mark.xfail async def test_proxy_https_acquired_cleanup_force( proxy_test_server: Callable[[], Awaitable[mock.Mock]], loop: asyncio.AbstractEventLoop, ) -> None: url = "https://secure.aiohttp.io/path" conn = aiohttp.TCPConnector(force_close=True) sess = aiohttp.ClientSession(connector=conn) try: proxy = await proxy_test_server() assert 0 == len(conn._acquired) async def request() -> None: async with sess.get(url, proxy=proxy.url): assert 1 == len(conn._acquired) await request() assert 0 == len(conn._acquired) finally: await sess.close() await conn.close() @pytest.mark.xfail async def test_proxy_https_multi_conn_limit( proxy_test_server: Callable[[], Awaitable[mock.Mock]], loop: asyncio.AbstractEventLoop, ) -> None: url = "https://secure.aiohttp.io/path" limit, multi_conn_num = 1, 5 conn = aiohttp.TCPConnector(limit=limit) sess = aiohttp.ClientSession(connector=conn) proxy = await proxy_test_server() try: current_pid = None async def 
# Keep a reference to the real implementation so the replacement can defer
# to it for every path it does not hide.
original_is_file = pathlib.Path.is_file


def mock_is_file(self: pathlib.Path) -> bool:
    """Stand in for ``pathlib.Path.is_file`` and hide the user's netrc file.

    Returns ``False`` for a path named ``_netrc`` or ``.netrc`` that sits
    directly in the home directory, which makes a real netrc file invisible.
    Any other path gets the answer of the original ``is_file``.
    """
    netrc_names = ["_netrc", ".netrc"]
    if self.name in netrc_names and self.parent == self.home():
        return False
    return original_is_file(self)
async def test_proxy_from_env_http_with_auth_from_netrc(
    proxy_test_server: Callable[[], Awaitable[mock.Mock]],
    tmp_path: pathlib.Path,
    mocker: MockerFixture,
) -> None:
    """Check proxy credentials are taken from the file named by ``NETRC``.

    ``http_proxy`` points at the fake proxy and the netrc file holds an
    entry for machine ``127.0.0.1``. The single request seen by the proxy
    must carry the matching ``Proxy-Authorization`` header.
    """
    auth = aiohttp.BasicAuth("user", "pass")
    proxy = await proxy_test_server()

    netrc_file = tmp_path / "test_netrc"
    netrc_file.write_text(
        f"machine 127.0.0.1 login {auth.login} password {auth.password}"
    )
    env_overrides = {"http_proxy": str(proxy.url), "NETRC": str(netrc_file)}
    mocker.patch.dict(os.environ, env_overrides)

    await get_request(url="http://aiohttp.io/path", trust_env=True)

    assert len(proxy.requests_list) == 1
    seen = proxy.request
    assert seen.method == "GET"
    assert seen.host == "aiohttp.io"
    assert seen.path_qs == "/path"
    assert seen.headers["Proxy-Authorization"] == auth.encode()
async def test_proxy_from_env_http_without_auth_from_wrong_netrc(
    proxy_test_server: Callable[[], Awaitable[mock.Mock]],
    tmp_path: pathlib.Path,
    mocker: MockerFixture,
) -> None:
    """Check a malformed netrc file yields no ``Proxy-Authorization`` header.

    The file named by ``NETRC`` holds a line that is not a valid netrc
    entry. The request must still go through the proxy taken from
    ``http_proxy``, just without proxy credentials.
    """
    proxy = await proxy_test_server()
    auth = aiohttp.BasicAuth("user", "pass")

    broken_netrc = tmp_path / "test_netrc"
    broken_netrc.write_text(f"machine 127.0.0.1 {auth.login} pass {auth.password}")

    mocker.patch.dict(
        os.environ,
        {"http_proxy": str(proxy.url), "NETRC": str(broken_netrc)},
    )

    await get_request(url="http://aiohttp.io/path", trust_env=True)

    assert len(proxy.requests_list) == 1
    last_request = proxy.request
    assert last_request.method == "GET"
    assert last_request.host == "aiohttp.io"
    assert last_request.path_qs == "/path"
    assert "Proxy-Authorization" not in last_request.headers
async def test_proxy_auth() -> None:
    """Check that a plain tuple passed as ``proxy_auth`` raises ``ValueError``."""
    expected_message = r"proxy_auth must be None or BasicAuth\(\) tuple"
    not_basic_auth = ("user", "pass")

    async with aiohttp.ClientSession() as session:
        with pytest.raises(ValueError, match=expected_message):
            async with session.get(
                "http://python.org",
                proxy="http://proxy.example.com",
                proxy_auth=not_basic_auth,  # type: ignore[arg-type]
            ):
                pass
# Create a minimal proxy server # The CONNECT method is handled at the protocol level, not by the handler proxy_app = web.Application() proxy_server = await aiohttp_server(proxy_app) proxy_url = f"http://{proxy_server.host}:{proxy_server.port}" # Create session and make HTTPS request through proxy session = aiohttp.ClientSession() try: # This will fail during TLS upgrade because proxy doesn't establish tunnel with suppress(aiohttp.ClientError): async with session.get("https://example.com/test", proxy=proxy_url) as resp: await resp.read() # The critical test: Check if any connections were pooled with proxy=None # This is the root cause of the hang - CONNECT tunnel connections # should NOT be pooled connector = session.connector assert connector is not None # Count connections with proxy=None in the pool proxy_none_keys = [key for key in connector._conns if key.proxy is None] proxy_none_count = len(proxy_none_keys) # Before the fix, there would be a connection with proxy=None # After the fix, CONNECT tunnel connections are not pooled assert proxy_none_count == 0, ( f"Found {proxy_none_count} connections with proxy=None in pool. 
" f"CONNECT tunnel connections should not be pooled - this is bug #11273" ) finally: # Clean close await session.close() ================================================ FILE: tests/test_pytest_plugin.py ================================================ import os import platform import warnings import pytest from aiohttp import pytest_plugin pytest_plugins: str = "pytester" CONFTEST: str = """ pytest_plugins = 'aiohttp.pytest_plugin' """ IS_PYPY = platform.python_implementation() == "PyPy" def test_aiohttp_plugin(testdir: pytest.Testdir) -> None: testdir.makepyfile("""\ import pytest from unittest import mock from aiohttp import web value = web.AppKey('value', str) async def hello(request): return web.Response(body=b'Hello, world') async def create_app(): app = web.Application() app.router.add_route('GET', '/', hello) return app async def test_hello(aiohttp_client) -> None: client = await aiohttp_client(await create_app()) resp = await client.get('/') assert resp.status == 200 text = await resp.text() assert 'Hello, world' in text async def test_hello_from_app(aiohttp_client) -> None: app = web.Application() app.router.add_get('/', hello) client = await aiohttp_client(app) resp = await client.get('/') assert resp.status == 200 text = await resp.text() assert 'Hello, world' in text async def test_hello_with_loop(aiohttp_client) -> None: client = await aiohttp_client(await create_app()) resp = await client.get('/') assert resp.status == 200 text = await resp.text() assert 'Hello, world' in text async def test_noop() -> None: pass async def previous(request): if request.method == 'POST': with pytest.deprecated_call(): # FIXME: this isn't actually called request.app[value] = (await request.post())['value'] return web.Response(body=b'thanks for the data') else: v = request.app.get(value, 'unknown') return web.Response(body='value: {}'.format(v).encode()) def create_stateful_app(): app = web.Application() app.router.add_route('*', '/', previous) return app @pytest.fixture 
def cli(loop, aiohttp_client): return loop.run_until_complete(aiohttp_client(create_stateful_app())) def test_noncoro() -> None: assert True async def test_failed_to_create_client(aiohttp_client) -> None: def make_app(): raise RuntimeError() with pytest.raises(RuntimeError): await aiohttp_client(make_app()) async def test_custom_port_aiohttp_client(aiohttp_client, aiohttp_unused_port): port = aiohttp_unused_port() client = await aiohttp_client(await create_app(), server_kwargs={'port': port}) assert client.port == port resp = await client.get('/') assert resp.status == 200 text = await resp.text() assert 'Hello, world' in text async def test_custom_port_test_server(aiohttp_server, aiohttp_unused_port): app = await create_app() port = aiohttp_unused_port() server = await aiohttp_server(app, port=port) assert server.port == port """) testdir.makeconftest(CONFTEST) result = testdir.runpytest("-p", "no:sugar", "--aiohttp-loop=pyloop") result.assert_outcomes(passed=8) def test_warning_checks(testdir: pytest.Testdir) -> None: testdir.makepyfile("""\ async def foobar(): return 123 async def test_good() -> None: v = await foobar() assert v == 123 async def test_bad() -> None: foobar() """) testdir.makeconftest(CONFTEST) result = testdir.runpytest( "-p", "no:sugar", "-s", "-W", "default", "--aiohttp-loop=pyloop" ) expected_outcomes = ( {"failed": 0, "passed": 2} if IS_PYPY and bool(os.environ.get("PYTHONASYNCIODEBUG")) else {"failed": 1, "passed": 1} ) # Under PyPy "coroutine 'foobar' was never awaited" does not happen. 
result.assert_outcomes(**expected_outcomes) def test_aiohttp_plugin_async_fixture( testdir: pytest.Testdir, capsys: pytest.CaptureFixture[str] ) -> None: testdir.makepyfile("""\ import pytest from aiohttp import web async def hello(request): return web.Response(body=b'Hello, world') def create_app(): app = web.Application() app.router.add_route('GET', '/', hello) return app @pytest.fixture async def cli(aiohttp_client, loop): client = await aiohttp_client(create_app()) return client @pytest.fixture async def foo(): return 42 @pytest.fixture async def bar(request): # request should be accessible in async fixtures if needed return request.function async def test_hello(cli, loop) -> None: resp = await cli.get('/') assert resp.status == 200 def test_foo(loop, foo) -> None: assert foo == 42 def test_foo_without_loop(foo) -> None: # will raise an error because there is no loop pass def test_bar(loop, bar) -> None: assert bar is test_bar """) testdir.makeconftest(CONFTEST) result = testdir.runpytest("-p", "no:sugar", "--aiohttp-loop=pyloop") result.assert_outcomes(passed=3, errors=1) result.stdout.fnmatch_lines( "*Asynchronous fixtures must depend on the 'loop' fixture " "or be used in tests depending from it." 
) def test_aiohttp_plugin_async_gen_fixture(testdir: pytest.Testdir) -> None: testdir.makepyfile("""\ import pytest from unittest import mock from aiohttp import web canary = mock.Mock() async def hello(request): return web.Response(body=b'Hello, world') def create_app(): app = web.Application() app.router.add_route('GET', '/', hello) return app @pytest.fixture async def cli(aiohttp_client, loop): yield await aiohttp_client(create_app()) canary() async def test_hello(cli) -> None: resp = await cli.get('/') assert resp.status == 200 def test_finalized() -> None: assert canary.called is True """) testdir.makeconftest(CONFTEST) result = testdir.runpytest("-p", "no:sugar", "--aiohttp-loop=pyloop") result.assert_outcomes(passed=2) def test_warnings_propagated(recwarn: pytest.WarningsRecorder) -> None: with pytest_plugin._runtime_warning_context(): warnings.warn("test warning is propagated") assert len(recwarn) == 1 message = recwarn[0].message assert isinstance(message, UserWarning) assert message.args == ("test warning is propagated",) def test_aiohttp_client_cls_fixture_custom_client_used(testdir: pytest.Testdir) -> None: testdir.makepyfile(""" import pytest from aiohttp.web import Application from aiohttp.test_utils import TestClient class CustomClient(TestClient): pass @pytest.fixture def aiohttp_client_cls(): return CustomClient async def test_hello(aiohttp_client) -> None: client = await aiohttp_client(Application()) assert isinstance(client, CustomClient) """) testdir.makeconftest(CONFTEST) result = testdir.runpytest() result.assert_outcomes(passed=1) def test_aiohttp_client_cls_fixture_factory(testdir: pytest.Testdir) -> None: testdir.makeconftest(CONFTEST + """ def pytest_configure(config): config.addinivalue_line("markers", "rest: RESTful API tests") config.addinivalue_line("markers", "graphql: GraphQL API tests") """) testdir.makepyfile(""" import pytest from aiohttp.web import Application from aiohttp.test_utils import TestClient class 
RESTfulClient(TestClient): pass class GraphQLClient(TestClient): pass @pytest.fixture def aiohttp_client_cls(request): if request.node.get_closest_marker('rest') is not None: return RESTfulClient elif request.node.get_closest_marker('graphql') is not None: return GraphQLClient return TestClient @pytest.mark.rest async def test_rest(aiohttp_client) -> None: client = await aiohttp_client(Application()) assert isinstance(client, RESTfulClient) @pytest.mark.graphql async def test_graphql(aiohttp_client) -> None: client = await aiohttp_client(Application()) assert isinstance(client, GraphQLClient) """) result = testdir.runpytest() result.assert_outcomes(passed=2) ================================================ FILE: tests/test_resolver.py ================================================ import asyncio import gc import ipaddress import socket from collections.abc import Awaitable, Callable, Collection, Generator from ipaddress import ip_address from typing import Any, NamedTuple from unittest.mock import Mock, create_autospec, patch import pytest from aiohttp.resolver import ( _NAME_SOCKET_FLAGS, AsyncResolver, DefaultResolver, ThreadedResolver, _DNSResolverManager, ) try: import aiodns getaddrinfo = hasattr(aiodns.DNSResolver, "getaddrinfo") except ImportError: # pragma: no cover aiodns = None # type: ignore[assignment] getaddrinfo = False _AddrInfo4 = list[ tuple[socket.AddressFamily, None, socket.SocketKind, None, tuple[str, int]] ] _AddrInfo6 = list[ tuple[ socket.AddressFamily, None, socket.SocketKind, None, tuple[str, int, int, int] ] ] _UnknownAddrInfo = list[ tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[int, bytes]] ] @pytest.fixture() def check_no_lingering_resolvers() -> Generator[None, None, None]: """Verify no resolvers remain after the test. This fixture should be used in any test that creates instances of AsyncResolver or directly uses _DNSResolverManager. 
""" manager = _DNSResolverManager() before = len(manager._loop_data) yield after = len(manager._loop_data) if after > before: # pragma: no branch # Force garbage collection to ensure weak references are updated gc.collect() # pragma: no cover after = len(manager._loop_data) # pragma: no cover if after > before: # pragma: no cover pytest.fail( # pragma: no cover f"Lingering resolvers found: {(after - before)} " "new AsyncResolver instances were not properly closed." ) @pytest.fixture() def dns_resolver_manager() -> Generator[_DNSResolverManager, None, None]: """Create a fresh _DNSResolverManager instance for testing. Saves and restores the singleton state to avoid affecting other tests. """ # Save the original instance original_instance = _DNSResolverManager._instance # Reset the singleton _DNSResolverManager._instance = None # Create and yield a fresh instance try: yield _DNSResolverManager() finally: # Clean up and restore the original instance _DNSResolverManager._instance = original_instance class FakeAIODNSAddrInfoNode(NamedTuple): family: int addr: tuple[bytes, int] | tuple[bytes, int, int, int] class FakeAIODNSAddrInfoIPv4Result: def __init__(self, hosts: Collection[str]) -> None: self.nodes = [ FakeAIODNSAddrInfoNode(socket.AF_INET, (h.encode(), 0)) for h in hosts ] class FakeAIODNSAddrInfoIPv6Result: def __init__(self, hosts: Collection[str]) -> None: self.nodes = [ FakeAIODNSAddrInfoNode( socket.AF_INET6, (h.encode(), 0, 0, 3 if ip_address(h).is_link_local else 0), ) for h in hosts ] class FakeAIODNSNameInfoIPv6Result: def __init__(self, host: str) -> None: self.node = host self.service = None async def fake_aiodns_getaddrinfo_ipv4_result( hosts: Collection[str], ) -> FakeAIODNSAddrInfoIPv4Result: return FakeAIODNSAddrInfoIPv4Result(hosts=hosts) async def fake_aiodns_getaddrinfo_ipv6_result( hosts: Collection[str], ) -> FakeAIODNSAddrInfoIPv6Result: return FakeAIODNSAddrInfoIPv6Result(hosts=hosts) async def fake_aiodns_getnameinfo_ipv6_result( host: str, ) 
-> FakeAIODNSNameInfoIPv6Result: return FakeAIODNSNameInfoIPv6Result(host)


def fake_addrinfo(hosts: Collection[str]) -> Callable[..., Awaitable[_AddrInfo4]]:
    """Build an async callable that returns canned IPv4 address info.

    The returned coroutine function ignores every argument it is given.
    It raises ``socket.gaierror`` when *hosts* is empty; otherwise it returns
    one ``(AF_INET, None, SOCK_STREAM, None, (host, 0))`` tuple per host.
    """

    async def fake(*args: Any, **kwargs: Any) -> _AddrInfo4:
        if not hosts:
            raise socket.gaierror

        return [(socket.AF_INET, None, socket.SOCK_STREAM, None, (h, 0)) for h in hosts]

    return fake


def fake_ipv6_addrinfo(hosts: Collection[str]) -> Callable[..., Awaitable[_AddrInfo6]]:
    """Build an async callable that returns canned IPv6 address info.

    The returned coroutine function ignores every argument it is given.
    It raises ``socket.gaierror`` when *hosts* is empty; otherwise it returns
    one ``AF_INET6``/``SOCK_STREAM`` tuple per host whose 4-item socket
    address ends with 3 for link-local addresses and 0 for all others.
    """

    async def fake(*args: Any, **kwargs: Any) -> _AddrInfo6:
        if not hosts:
            raise socket.gaierror

        return [
            (
                socket.AF_INET6,
                None,
                socket.SOCK_STREAM,
                None,
                # Last item is non-zero (3) only for link-local hosts.
                (h, 0, 0, 3 if ip_address(h).is_link_local else 0),
            )
            for h in hosts
        ]

    return fake


def fake_ipv6_nameinfo(host: str) -> Callable[..., Awaitable[tuple[str, int]]]:
    """Build an async callable that always returns ``(host, 0)``.

    The returned coroutine function ignores every argument it is given.
    """

    async def fake(*args: Any, **kwargs: Any) -> tuple[str, int]:
        return host, 0

    return fake


@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required")
@pytest.mark.usefixtures("check_no_lingering_resolvers")
async def test_async_resolver_positive_ipv4_lookup(
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Resolve a name through a patched ``aiodns.DNSResolver`` giving one IPv4 host."""
    with patch("aiodns.DNSResolver") as mock:
        mock().getaddrinfo.return_value = fake_aiodns_getaddrinfo_ipv4_result(
            ["127.0.0.1"]
        )
        resolver = AsyncResolver()
        real = await resolver.resolve("www.python.org")
        # ip_address() raises ValueError unless the host is a valid IP literal.
        ipaddress.ip_address(real[0]["host"])
        mock().getaddrinfo.assert_called_with(
            "www.python.org",
            family=socket.AF_INET,
            flags=socket.AI_ADDRCONFIG,
            port=0,
            type=socket.SOCK_STREAM,
        )
        await resolver.close()


@pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") @pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_positive_link_local_ipv6_lookup( loop: asyncio.AbstractEventLoop, ) -> None: with patch("aiodns.DNSResolver") as mock: mock().getaddrinfo.return_value = fake_aiodns_getaddrinfo_ipv6_result( ["fe80::1"] ) mock().getnameinfo.return_value = fake_aiodns_getnameinfo_ipv6_result( "fe80::1%eth0" ) resolver = AsyncResolver() real = await resolver.resolve("www.python.org")
ipaddress.ip_address(real[0]["host"]) mock().getaddrinfo.assert_called_with( "www.python.org", family=socket.AF_INET, flags=socket.AI_ADDRCONFIG, port=0, type=socket.SOCK_STREAM, ) mock().getnameinfo.assert_called_with(("fe80::1", 0, 0, 3), _NAME_SOCKET_FLAGS) await resolver.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") @pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_multiple_replies(loop: asyncio.AbstractEventLoop) -> None: with patch("aiodns.DNSResolver") as mock: ips = ["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4"] mock().getaddrinfo.return_value = fake_aiodns_getaddrinfo_ipv4_result(ips) resolver = AsyncResolver() real = await resolver.resolve("www.google.com") ipaddrs = [ipaddress.ip_address(x["host"]) for x in real] assert len(ipaddrs) > 3, "Expecting multiple addresses" await resolver.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") @pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_negative_lookup(loop: asyncio.AbstractEventLoop) -> None: with patch("aiodns.DNSResolver") as mock: mock().getaddrinfo.side_effect = aiodns.error.DNSError() resolver = AsyncResolver() with pytest.raises(OSError): await resolver.resolve("doesnotexist.bla") await resolver.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") @pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_no_hosts_in_getaddrinfo( loop: asyncio.AbstractEventLoop, ) -> None: with patch("aiodns.DNSResolver") as mock: mock().getaddrinfo.return_value = fake_aiodns_getaddrinfo_ipv4_result([]) resolver = AsyncResolver() with pytest.raises(OSError): await resolver.resolve("doesnotexist.bla") await resolver.close() async def test_threaded_resolver_positive_lookup() -> None: loop = Mock() loop.getaddrinfo = fake_addrinfo(["127.0.0.1"]) resolver = ThreadedResolver() resolver._loop = loop real = await 
resolver.resolve("www.python.org") assert real[0]["hostname"] == "www.python.org" ipaddress.ip_address(real[0]["host"]) async def test_threaded_resolver_positive_ipv6_link_local_lookup() -> None: loop = Mock() loop.getaddrinfo = fake_ipv6_addrinfo(["fe80::1"]) loop.getnameinfo = fake_ipv6_nameinfo("fe80::1%eth0") # Mock the fake function that was returned by helper functions loop.getaddrinfo = create_autospec(loop.getaddrinfo) loop.getnameinfo = create_autospec(loop.getnameinfo) # Set the correct return values for mock functions loop.getaddrinfo.return_value = await fake_ipv6_addrinfo(["fe80::1"])() loop.getnameinfo.return_value = await fake_ipv6_nameinfo("fe80::1%eth0")() resolver = ThreadedResolver() resolver._loop = loop real = await resolver.resolve("www.python.org") assert real[0]["hostname"] == "www.python.org" ipaddress.ip_address(real[0]["host"]) loop.getaddrinfo.assert_called_with( "www.python.org", 0, type=socket.SOCK_STREAM, family=socket.AF_INET, flags=socket.AI_ADDRCONFIG, ) loop.getnameinfo.assert_called_with(("fe80::1", 0, 0, 3), _NAME_SOCKET_FLAGS) async def test_threaded_resolver_multiple_replies() -> None: loop = Mock() ips = ["127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4"] loop.getaddrinfo = fake_addrinfo(ips) resolver = ThreadedResolver() resolver._loop = loop real = await resolver.resolve("www.google.com") ipaddrs = [ipaddress.ip_address(x["host"]) for x in real] assert len(ipaddrs) > 3, "Expecting multiple addresses" async def test_threaded_negative_lookup() -> None: loop = Mock() ips: list[str] = [] loop.getaddrinfo = fake_addrinfo(ips) resolver = ThreadedResolver() resolver._loop = loop with pytest.raises(socket.gaierror): await resolver.resolve("doesnotexist.bla") async def test_threaded_negative_ipv6_lookup() -> None: loop = Mock() ips: list[str] = [] loop.getaddrinfo = fake_ipv6_addrinfo(ips) resolver = ThreadedResolver() resolver._loop = loop with pytest.raises(socket.gaierror): await resolver.resolve("doesnotexist.bla") async def 
test_threaded_negative_lookup_with_unknown_result() -> None: loop = Mock() # If compile CPython with `--disable-ipv6` option, # we will get an (int, bytes) tuple, instead of a Exception. async def unknown_addrinfo(*args: Any, **kwargs: Any) -> _UnknownAddrInfo: return [ ( socket.AF_INET6, socket.SOCK_STREAM, 6, "", (10, b"\x01\xbb\x00\x00\x00\x00*\x04NB\x00\x1a\x00\x00"), ) ] loop.getaddrinfo = unknown_addrinfo resolver = ThreadedResolver() resolver._loop = loop with patch("socket.has_ipv6", False): res = await resolver.resolve("www.python.org") assert len(res) == 0 async def test_close_for_threaded_resolver(loop: asyncio.AbstractEventLoop) -> None: resolver = ThreadedResolver() await resolver.close() @pytest.mark.skipif(aiodns is None, reason="aiodns required") @pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_close_for_async_resolver(loop: asyncio.AbstractEventLoop) -> None: resolver = AsyncResolver() await resolver.close() async def test_default_loop_for_threaded_resolver( loop: asyncio.AbstractEventLoop, ) -> None: asyncio.set_event_loop(loop) resolver = ThreadedResolver() assert resolver._loop is loop @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") @pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_ipv6_positive_lookup( loop: asyncio.AbstractEventLoop, ) -> None: with patch("aiodns.DNSResolver") as mock: mock().getaddrinfo.return_value = fake_aiodns_getaddrinfo_ipv6_result(["::1"]) resolver = AsyncResolver() real = await resolver.resolve("www.python.org") ipaddress.ip_address(real[0]["host"]) mock().getaddrinfo.assert_called_with( "www.python.org", family=socket.AF_INET, flags=socket.AI_ADDRCONFIG, port=0, type=socket.SOCK_STREAM, ) await resolver.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") @pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_error_messages_passed( loop: asyncio.AbstractEventLoop, ) -> None: 
"""Ensure error messages are passed through from aiodns.""" with patch("aiodns.DNSResolver", autospec=True, spec_set=True) as mock: mock().getaddrinfo.side_effect = aiodns.error.DNSError(1, "Test error message") resolver = AsyncResolver() with pytest.raises(OSError, match="Test error message") as excinfo: await resolver.resolve("x.org") assert excinfo.value.strerror == "Test error message" await resolver.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") @pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_error_messages_passed_no_hosts( loop: asyncio.AbstractEventLoop, ) -> None: """Ensure error messages are passed through from aiodns.""" with patch("aiodns.DNSResolver", autospec=True, spec_set=True) as mock: mock().getaddrinfo.return_value = fake_aiodns_getaddrinfo_ipv6_result([]) resolver = AsyncResolver() with pytest.raises(OSError, match="DNS lookup failed") as excinfo: await resolver.resolve("x.org") assert excinfo.value.strerror == "DNS lookup failed" await resolver.close() @pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_aiodns_not_present( loop: asyncio.AbstractEventLoop, monkeypatch: pytest.MonkeyPatch ) -> None: monkeypatch.setattr("aiohttp.resolver.aiodns", None) with pytest.raises(RuntimeError): AsyncResolver() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") @pytest.mark.usefixtures("check_no_lingering_resolvers") def test_aio_dns_is_default() -> None: assert DefaultResolver is AsyncResolver @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") async def test_dns_resolver_manager_sharing( dns_resolver_manager: _DNSResolverManager, ) -> None: """Test that the DNSResolverManager shares a resolver among AsyncResolver instances.""" # Create two default AsyncResolver instances resolver1 = AsyncResolver() resolver2 = AsyncResolver() # Check that they share the same underlying resolver assert resolver1._resolver is 
resolver2._resolver # Create an AsyncResolver with custom args resolver3 = AsyncResolver(nameservers=["8.8.8.8"]) # Check that it has its own resolver assert resolver1._resolver is not resolver3._resolver # Cleanup await resolver1.close() await resolver2.close() await resolver3.close() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") async def test_dns_resolver_manager_singleton( dns_resolver_manager: _DNSResolverManager, ) -> None: """Test that DNSResolverManager is a singleton.""" # Create a second manager and check it's the same instance manager1 = dns_resolver_manager manager2 = _DNSResolverManager() assert manager1 is manager2 @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") async def test_dns_resolver_manager_resolver_lifecycle( dns_resolver_manager: _DNSResolverManager, ) -> None: """Test that DNSResolverManager creates and destroys resolver correctly.""" manager = dns_resolver_manager # Initially there should be no resolvers assert not manager._loop_data # Create a mock AsyncResolver for testing mock_client = Mock(spec=AsyncResolver) mock_client._loop = asyncio.get_running_loop() # Getting resolver should create one mock_loop = mock_client._loop resolver = manager.get_resolver(mock_client, mock_loop) assert resolver is not None assert manager._loop_data[mock_loop][0] is resolver # Getting it again should return the same instance assert manager.get_resolver(mock_client, mock_loop) is resolver # Clean up manager.release_resolver(mock_client, mock_loop) assert not manager._loop_data @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") async def test_dns_resolver_manager_client_registration( dns_resolver_manager: _DNSResolverManager, ) -> None: """Test client registration and resolver release logic.""" with patch("aiodns.DNSResolver") as mock: # Create resolver instances resolver1 = AsyncResolver() resolver2 = AsyncResolver() # Both should use the same resolver from the manager assert 
resolver1._resolver is resolver2._resolver # The manager should be tracking both clients assert resolver1._manager is resolver2._manager manager = resolver1._manager assert manager is not None loop = asyncio.get_running_loop() _, client_set = manager._loop_data[loop] assert len(client_set) == 2 # Close one resolver await resolver1.close() _, client_set = manager._loop_data[loop] assert len(client_set) == 1 # Resolver should still exist assert manager._loop_data # Not empty # Close the second resolver await resolver2.close() assert not manager._loop_data # Should be empty after closing all clients # Now all resolvers should be canceled and removed assert not manager._loop_data # Should be empty mock().cancel.assert_called_once() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") async def test_dns_resolver_manager_multiple_event_loops( dns_resolver_manager: _DNSResolverManager, ) -> None: """Test that DNSResolverManager correctly manages resolvers across different event loops.""" # Create separate resolvers for each loop resolver1 = Mock(name="resolver1") resolver2 = Mock(name="resolver2") # Create a patch that returns different resolvers based on the loop argument mock_resolver = Mock() mock_resolver.side_effect = lambda loop=None, **kwargs: ( resolver1 if loop is asyncio.get_running_loop() else resolver2 ) with patch("aiodns.DNSResolver", mock_resolver): manager = dns_resolver_manager # Create two mock clients on different loops mock_client1 = Mock(spec=AsyncResolver) mock_client1._loop = asyncio.get_running_loop() # Create a second event loop loop2 = Mock(spec=asyncio.AbstractEventLoop) mock_client2 = Mock(spec=AsyncResolver) mock_client2._loop = loop2 # Get resolvers for both clients loop1 = mock_client1._loop loop2 = mock_client2._loop # Get the resolvers through the manager manager_resolver1 = manager.get_resolver(mock_client1, loop1) manager_resolver2 = manager.get_resolver(mock_client2, loop2) # Should be different resolvers for different 
loops assert manager_resolver1 is resolver1 assert manager_resolver2 is resolver2 assert manager._loop_data[loop1][0] is resolver1 assert manager._loop_data[loop2][0] is resolver2 # Release the first resolver manager.release_resolver(mock_client1, loop1) # First loop's resolver should be gone, but second should remain assert loop1 not in manager._loop_data assert loop2 in manager._loop_data # Release the second resolver manager.release_resolver(mock_client2, loop2) # Both resolvers should be gone assert not manager._loop_data # Verify resolver cleanup resolver1.cancel.assert_called_once() resolver2.cancel.assert_called_once() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") async def test_dns_resolver_manager_weakref_garbage_collection() -> None: """Test that release_resolver handles None resolver due to weakref garbage collection.""" manager = _DNSResolverManager() # Create a mock resolver that will be None when accessed mock_resolver = Mock() mock_resolver.cancel = Mock() with patch("aiodns.DNSResolver", return_value=mock_resolver): # Create an AsyncResolver to get a resolver from the manager resolver = AsyncResolver() loop = asyncio.get_running_loop() # Manually corrupt the data to simulate garbage collection # by setting the resolver to None manager._loop_data[loop] = (None, manager._loop_data[loop][1]) # type: ignore[assignment] # This should not raise an AttributeError: 'NoneType' object has no attribute 'cancel' await resolver.close() # Verify no exception was raised and the loop data was cleaned up properly # Since we set resolver to None and there was one client, the entry should be removed assert loop not in manager._loop_data @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") async def test_dns_resolver_manager_missing_loop_data() -> None: """Test that release_resolver handles missing loop data gracefully.""" manager = _DNSResolverManager() with patch("aiodns.DNSResolver"): # Create an AsyncResolver resolver = 
AsyncResolver() loop = asyncio.get_running_loop() # Manually remove the loop data to simulate race condition manager._loop_data.clear() # This should not raise a KeyError await resolver.close() # Verify no exception was raised assert loop not in manager._loop_data @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") @pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_close_multiple_times() -> None: """Test that AsyncResolver.close() can be called multiple times without error.""" with patch("aiodns.DNSResolver") as mock_dns_resolver: mock_resolver = Mock() mock_resolver.cancel = Mock() mock_dns_resolver.return_value = mock_resolver # Create a resolver with custom args (dedicated resolver) resolver = AsyncResolver(nameservers=["8.8.8.8"]) # Close it once await resolver.close() mock_resolver.cancel.assert_called_once() # Close it again - should not raise AttributeError await resolver.close() # cancel should still only be called once mock_resolver.cancel.assert_called_once() @pytest.mark.skipif(not getaddrinfo, reason="aiodns >=3.2.0 required") @pytest.mark.usefixtures("check_no_lingering_resolvers") async def test_async_resolver_close_with_none_resolver() -> None: """Test that AsyncResolver.close() handles None resolver gracefully.""" with patch("aiodns.DNSResolver"): # Create a resolver with custom args (dedicated resolver) resolver = AsyncResolver(nameservers=["8.8.8.8"]) # Manually set resolver to None to simulate edge case resolver._resolver = None # type: ignore[assignment] # This should not raise AttributeError await resolver.close() ================================================ FILE: tests/test_route_def.py ================================================ import pathlib from typing import NoReturn import pytest from yarl import URL from aiohttp import web from aiohttp.web_urldispatcher import UrlDispatcher @pytest.fixture def router() -> UrlDispatcher: return UrlDispatcher() def test_get(router: 
UrlDispatcher) -> None: async def handler(request: web.Request) -> NoReturn: assert False router.add_routes([web.get("/", handler)]) assert len(router.routes()) == 2 # GET and HEAD route = list(router.routes())[1] assert route.handler is handler assert route.method == "GET" assert str(route.url_for()) == "/" route2 = list(router.routes())[0] assert route2.handler is handler assert route2.method == "HEAD" def test_head(router: UrlDispatcher) -> None: async def handler(request: web.Request) -> NoReturn: assert False router.add_routes([web.head("/", handler)]) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.handler is handler assert route.method == "HEAD" assert str(route.url_for()) == "/" def test_options(router: UrlDispatcher) -> None: async def handler(request: web.Request) -> NoReturn: assert False router.add_routes([web.options("/", handler)]) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.handler is handler assert route.method == "OPTIONS" assert str(route.url_for()) == "/" def test_post(router: UrlDispatcher) -> None: async def handler(request: web.Request) -> NoReturn: assert False router.add_routes([web.post("/", handler)]) route = list(router.routes())[0] assert route.handler is handler assert route.method == "POST" assert str(route.url_for()) == "/" def test_put(router: UrlDispatcher) -> None: async def handler(request: web.Request) -> NoReturn: assert False router.add_routes([web.put("/", handler)]) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.handler is handler assert route.method == "PUT" assert str(route.url_for()) == "/" def test_patch(router: UrlDispatcher) -> None: async def handler(request: web.Request) -> NoReturn: assert False router.add_routes([web.patch("/", handler)]) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.handler is handler assert route.method == "PATCH" assert str(route.url_for()) == "/" def 
test_delete(router: UrlDispatcher) -> None: async def handler(request: web.Request) -> NoReturn: assert False router.add_routes([web.delete("/", handler)]) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.handler is handler assert route.method == "DELETE" assert str(route.url_for()) == "/" def test_route(router: UrlDispatcher) -> None: async def handler(request: web.Request) -> NoReturn: assert False router.add_routes([web.route("OTHER", "/", handler)]) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.handler is handler assert route.method == "OTHER" assert str(route.url_for()) == "/" def test_static(router: UrlDispatcher) -> None: folder = pathlib.Path(__file__).parent router.add_routes([web.static("/prefix", folder)]) assert len(router.resources()) == 1 # 2 routes: for HEAD and GET resource = list(router.resources())[0] info = resource.get_info() assert info["prefix"] == "/prefix" assert info["directory"] == folder url = resource.url_for(filename="aiohttp.png") assert url == URL("/prefix/aiohttp.png") def test_head_deco(router: UrlDispatcher) -> None: routes = web.RouteTableDef() @routes.head("/path") async def handler(request: web.Request) -> NoReturn: assert False router.add_routes(routes) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.method == "HEAD" assert str(route.url_for()) == "/path" def test_get_deco(router: UrlDispatcher) -> None: routes = web.RouteTableDef() @routes.get("/path") async def handler(request: web.Request) -> NoReturn: assert False router.add_routes(routes) assert len(router.routes()) == 2 route1 = list(router.routes())[0] assert route1.method == "HEAD" assert str(route1.url_for()) == "/path" route2 = list(router.routes())[1] assert route2.method == "GET" assert str(route2.url_for()) == "/path" def test_post_deco(router: UrlDispatcher) -> None: routes = web.RouteTableDef() @routes.post("/path") async def handler(request: web.Request) -> NoReturn: 
assert False router.add_routes(routes) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.method == "POST" assert str(route.url_for()) == "/path" def test_put_deco(router: UrlDispatcher) -> None: routes = web.RouteTableDef() @routes.put("/path") async def handler(request: web.Request) -> NoReturn: assert False router.add_routes(routes) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.method == "PUT" assert str(route.url_for()) == "/path" def test_patch_deco(router: UrlDispatcher) -> None: routes = web.RouteTableDef() @routes.patch("/path") async def handler(request: web.Request) -> NoReturn: assert False router.add_routes(routes) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.method == "PATCH" assert str(route.url_for()) == "/path" def test_delete_deco(router: UrlDispatcher) -> None: routes = web.RouteTableDef() @routes.delete("/path") async def handler(request: web.Request) -> NoReturn: assert False router.add_routes(routes) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.method == "DELETE" assert str(route.url_for()) == "/path" def test_options_deco(router: UrlDispatcher) -> None: routes = web.RouteTableDef() @routes.options("/path") async def handler(request: web.Request) -> NoReturn: assert False router.add_routes(routes) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.method == "OPTIONS" assert str(route.url_for()) == "/path" def test_route_deco(router: UrlDispatcher) -> None: routes = web.RouteTableDef() @routes.route("OTHER", "/path") async def handler(request: web.Request) -> NoReturn: assert False router.add_routes(routes) assert len(router.routes()) == 1 route = list(router.routes())[0] assert route.method == "OTHER" assert str(route.url_for()) == "/path" def test_routedef_sequence_protocol() -> None: routes = web.RouteTableDef() @routes.delete("/path") async def handler(request: web.Request) -> NoReturn: 
def test_repr_route_def() -> None:
    """``repr`` of a ``RouteDef`` shows method, path and handler name."""
    routes = web.RouteTableDef()

    @routes.get("/path")
    async def handler(request: web.Request) -> NoReturn:
        assert False

    rd = routes[0]
    # The expected text starts with "<RouteDef ..."; the angle-bracketed prefix
    # had been stripped from this literal, which made the comparison impossible
    # to satisfy.
    assert repr(rd) == "<RouteDef GET /path -> 'handler'>"


def test_repr_route_def_with_extra_info() -> None:
    """Extra keyword arguments given to the decorator are appended to the repr."""
    routes = web.RouteTableDef()

    @routes.get("/path", extra="info")
    async def handler(request: web.Request) -> NoReturn:
        assert False

    rd = routes[0]
    assert repr(rd) == "<RouteDef GET /path -> 'handler', extra='info'>"


def test_repr_static_def() -> None:
    """``repr`` of a static definition shows prefix, path and keyword arguments."""
    routes = web.RouteTableDef()

    routes.static("/prefix", "/path", name="name")

    rd = routes[0]
    assert repr(rd) == "<StaticDef /prefix -> /path, name='name'>"


def test_repr_route_table_def() -> None:
    """``repr`` of a ``RouteTableDef`` reports how many definitions it holds."""
    routes = web.RouteTableDef()

    @routes.get("/path")
    async def handler(request: web.Request) -> NoReturn:
        assert False

    # One handler was registered above, hence ``count=1``.
    assert repr(routes) == "<RouteTableDef count=1>"
def stopper(loop: asyncio.AbstractEventLoop) -> Callable[[], None]:
    """Return a callable that schedules a ``KeyboardInterrupt`` on *loop*.

    The returned callable accepts and ignores any positional arguments; each
    invocation queues, via ``loop.call_soon``, a callback that raises
    ``KeyboardInterrupt``.
    """

    def _interrupt() -> NoReturn:
        raise KeyboardInterrupt

    def _schedule_interrupt(*args: object) -> None:
        # Deferred with call_soon so the exception is raised from inside the
        # loop's callback processing, not by the caller of this function.
        loop.call_soon(_interrupt)

    return _schedule_interrupt
patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None ) startup_handler.assert_called_once_with(app) cleanup_handler.assert_called_once_with(app) def test_run_app_close_loop(patched_loop: asyncio.AbstractEventLoop) -> None: app = web.Application() web.run_app(app, print=stopper(patched_loop), loop=patched_loop) patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None ) assert patched_loop.is_closed() mock_unix_server_single = [ mock.call(mock.ANY, "/tmp/testsock1.sock", ssl=None, backlog=128), ] mock_unix_server_multi = [ mock.call(mock.ANY, "/tmp/testsock1.sock", ssl=None, backlog=128), mock.call(mock.ANY, "/tmp/testsock2.sock", ssl=None, backlog=128), ] mock_server_single = [ mock.call( mock.ANY, "127.0.0.1", 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None, ), ] mock_server_multi = [ mock.call( mock.ANY, "127.0.0.1", 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None, ), mock.call( mock.ANY, "192.168.1.1", 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None, ), ] mock_server_default_8989 = [ mock.call( mock.ANY, None, 8989, ssl=None, backlog=128, reuse_address=None, reuse_port=None ) ] mock_socket = mock.Mock(getsockname=lambda: ("mock-socket", 123)) mixed_bindings_tests: tuple[ tuple[str, dict[str, Any], list[mock._Call], list[mock._Call]], ... 
] = ( ( "Nothing Specified", {}, [ mock.call( mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None, ) ], [], ), ("Port Only", {"port": 8989}, mock_server_default_8989, []), ("Multiple Hosts", {"host": ("127.0.0.1", "192.168.1.1")}, mock_server_multi, []), ( "Multiple Paths", {"path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock")}, [], mock_unix_server_multi, ), ( "Multiple Paths, Port", {"path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock"), "port": 8989}, mock_server_default_8989, mock_unix_server_multi, ), ( "Multiple Paths, Single Host", {"path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock"), "host": "127.0.0.1"}, mock_server_single, mock_unix_server_multi, ), ( "Single Path, Single Host", {"path": "/tmp/testsock1.sock", "host": "127.0.0.1"}, mock_server_single, mock_unix_server_single, ), ( "Single Path, Multiple Hosts", {"path": "/tmp/testsock1.sock", "host": ("127.0.0.1", "192.168.1.1")}, mock_server_multi, mock_unix_server_single, ), ( "Single Path, Port", {"path": "/tmp/testsock1.sock", "port": 8989}, mock_server_default_8989, mock_unix_server_single, ), ( "Multiple Paths, Multiple Hosts, Port", { "path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock"), "host": ("127.0.0.1", "192.168.1.1"), "port": 8000, }, [ mock.call( mock.ANY, "127.0.0.1", 8000, ssl=None, backlog=128, reuse_address=None, reuse_port=None, ), mock.call( mock.ANY, "192.168.1.1", 8000, ssl=None, backlog=128, reuse_address=None, reuse_port=None, ), ], mock_unix_server_multi, ), ( "Only socket", {"sock": [mock_socket]}, [mock.call(mock.ANY, ssl=None, sock=mock_socket, backlog=128)], [], ), ( "Socket, port", {"sock": [mock_socket], "port": 8765}, [ mock.call( mock.ANY, None, 8765, ssl=None, backlog=128, reuse_address=None, reuse_port=None, ), mock.call(mock.ANY, sock=mock_socket, ssl=None, backlog=128), ], [], ), ( "Socket, Host, No port", {"sock": [mock_socket], "host": "localhost"}, [ mock.call( mock.ANY, "localhost", 8080, ssl=None, backlog=128, 
reuse_address=None, reuse_port=None, ), mock.call(mock.ANY, sock=mock_socket, ssl=None, backlog=128), ], [], ), ( "reuse_port", {"reuse_port": True}, [ mock.call( mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=True, ) ], [], ), ( "reuse_address", {"reuse_address": False}, [ mock.call( mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=False, reuse_port=None, ) ], [], ), ( "reuse_port, reuse_address", {"reuse_address": True, "reuse_port": True}, [ mock.call( mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=True, reuse_port=True, ) ], [], ), ( "Port, reuse_port", {"port": 8989, "reuse_port": True}, [ mock.call( mock.ANY, None, 8989, ssl=None, backlog=128, reuse_address=None, reuse_port=True, ) ], [], ), ( "Multiple Hosts, reuse_port", {"host": ("127.0.0.1", "192.168.1.1"), "reuse_port": True}, [ mock.call( mock.ANY, "127.0.0.1", 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=True, ), mock.call( mock.ANY, "192.168.1.1", 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=True, ), ], [], ), ( "Multiple Paths, Port, reuse_address", { "path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock"), "port": 8989, "reuse_address": False, }, [ mock.call( mock.ANY, None, 8989, ssl=None, backlog=128, reuse_address=False, reuse_port=None, ) ], mock_unix_server_multi, ), ( "Multiple Paths, Single Host, reuse_address, reuse_port", { "path": ("/tmp/testsock1.sock", "/tmp/testsock2.sock"), "host": "127.0.0.1", "reuse_address": True, "reuse_port": True, }, [ mock.call( mock.ANY, "127.0.0.1", 8080, ssl=None, backlog=128, reuse_address=True, reuse_port=True, ), ], mock_unix_server_multi, ), ) mixed_bindings_test_ids = [test[0] for test in mixed_bindings_tests] mixed_bindings_test_params = [test[1:] for test in mixed_bindings_tests] @pytest.mark.parametrize( "run_app_kwargs, expected_server_calls, expected_unix_server_calls", mixed_bindings_test_params, ids=mixed_bindings_test_ids, ) def test_run_app_mixed_bindings( # type: 
def test_run_app_multiple_hosts(patched_loop: asyncio.AbstractEventLoop) -> None:
    """``run_app`` given two hosts requests a server for each on port 8080."""
    hosts = ("127.0.0.1", "127.0.0.2")

    app = web.Application()
    web.run_app(app, host=hosts, print=stopper(patched_loop), loop=patched_loop)

    # Build a real list instead of a lazy ``map`` object: ``assert_has_calls``
    # walks the expected calls a second time when it formats a failure message,
    # and an already-consumed iterator would be reported as an empty list.
    expected_calls = [
        mock.call(
            mock.ANY,
            host,
            8080,
            ssl=None,
            backlog=128,
            reuse_address=None,
            reuse_port=None,
        )
        for host in hosts
    ]
    patched_loop.create_server.assert_has_calls(expected_calls)  # type: ignore[attr-defined]
def test_run_app_preexisting_inet_socket(
    patched_loop: asyncio.AbstractEventLoop, mocker: MockerFixture
) -> None:
    """A pre-bound IPv4 socket is handed to ``create_server`` and announced."""
    app = web.Application()

    with contextlib.closing(socket.socket()) as sock:
        # Port 0 asks the OS for any free port; read back the one assigned.
        sock.bind(("127.0.0.1", 0))
        port = sock.getsockname()[1]

        printer = mock.Mock(wraps=stopper(patched_loop))
        web.run_app(app, sock=sock, print=printer, loop=patched_loop)

        patched_loop.create_server.assert_called_with(  # type: ignore[attr-defined]
            mock.ANY, sock=sock, backlog=128, ssl=None
        )
        printed_text = printer.call_args[0][0]
        assert f"http://127.0.0.1:{port}" in printed_text
def test_sigint() -> None:
    """The ``_script_test_signal`` child exits with status 0 after SIGINT."""
    skip_if_on_windows()

    command = (sys.executable, "-u", "-c", _script_test_signal)
    with subprocess.Popen(command, stdout=subprocess.PIPE) as child:
        # Wait for the startup banner before signalling the child.
        banner = child.stdout.readline()  # type: ignore[union-attr]
        assert banner.startswith(b"======== Running on")
        child.send_signal(signal.SIGINT)
        exit_status = child.wait()
        assert exit_status == 0


def test_sigterm() -> None:
    """The ``_script_test_signal`` child exits with status 0 after ``terminate()``."""
    skip_if_on_windows()

    command = (sys.executable, "-u", "-c", _script_test_signal)
    with subprocess.Popen(command, stdout=subprocess.PIPE) as child:
        # Wait for the startup banner before terminating the child.
        banner = child.stdout.readline()  # type: ignore[union-attr]
        assert banner.startswith(b"======== Running on")
        child.terminate()
        exit_status = child.wait()
        assert exit_status == 0
def test_run_app_coro(patched_loop: asyncio.AbstractEventLoop) -> None:
    """``run_app`` accepts a coroutine that builds the application.

    The startup and cleanup handlers attached inside that coroutine must each
    be called exactly once.
    """
    startup_handler = None
    cleanup_handler = None

    async def make_app() -> web.Application:
        nonlocal startup_handler, cleanup_handler
        startup_handler = mock.AsyncMock()
        cleanup_handler = mock.AsyncMock()
        built = web.Application()
        built.on_startup.append(startup_handler)
        built.on_cleanup.append(cleanup_handler)
        return built

    web.run_app(make_app(), print=stopper(patched_loop), loop=patched_loop)

    patched_loop.create_server.assert_called_with(  # type: ignore[attr-defined]
        mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None
    )
    # Both names stay None unless make_app() actually ran.
    assert startup_handler is not None
    assert cleanup_handler is not None
    startup_handler.assert_called_once_with(mock.ANY)
    cleanup_handler.assert_called_once_with(mock.ANY)
def test_run_app_cancels_all_pending_tasks(
    patched_loop: asyncio.AbstractEventLoop,
) -> None:
    """A task still pending when ``run_app`` returns has been cancelled."""
    app = web.Application()
    task = None

    async def on_startup(app: web.Application) -> None:
        nonlocal task
        # Inside a coroutine the running loop is the one to use; this matches
        # ``TestShutdown.stop`` in this module, which uses get_running_loop().
        loop = asyncio.get_running_loop()
        # Sleeps far longer than the test runs, so the task is still pending
        # at shutdown.
        task = loop.create_task(asyncio.sleep(1000))

    app.on_startup.append(on_startup)

    web.run_app(app, print=stopper(patched_loop), loop=patched_loop)
    assert task is not None
    assert task.cancelled()
@pytest.mark.parametrize(
    "param",
    (
        "keepalive_timeout",
        "max_line_size",
        "max_headers",
        "max_field_size",
        "lingering_time",
        "read_bufsize",
        "auto_decompress",
    ),
)
def test_run_app_pass_apprunner_kwargs(
    param: str,
    patched_loop: asyncio.AbstractEventLoop,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Each listed ``run_app`` keyword reaches ``BaseRunner.__init__`` unchanged."""
    sentinel = mock.Mock()
    original_init = BaseRunner.__init__

    def spying_init(
        self: BaseRunner[web.Request], *args: Any, **kwargs: Any
    ) -> None:
        # Identity check: the very object given to run_app must arrive here.
        assert kwargs[param] is sentinel
        original_init(self, *args, **kwargs)

    monkeypatch.setattr(BaseRunner, "__init__", spying_init)

    app = web.Application()
    forwarded = {param: sentinel}
    web.run_app(app, print=stopper(patched_loop), loop=patched_loop, **forwarded)
def test_run_app_raises_exception(patched_loop: asyncio.AbstractEventLoop) -> None:
    """An error from a cleanup context propagates out of ``run_app``.

    The loop's ``call_exception_handler`` must not be called for it.
    """

    async def context(app: web.Application) -> AsyncIterator[None]:
        raise RuntimeError("foo")
        # Never reached; its presence makes this function an async generator.
        yield  # type: ignore[unreachable]  # pragma: no cover

    app = web.Application()
    app.cleanup_ctx.append(context)

    handler_patch = mock.patch.object(
        patched_loop, "call_exception_handler", autospec=True, spec_set=True
    )
    with handler_patch as m, pytest.raises(RuntimeError, match="foo"):
        web.run_app(app, loop=patched_loop)

    assert not m.called
def test_shutdown_wait_for_handler(self, unused_port_socket: socket.socket) -> None:
    """A 2-second handler task completes when ``run_app`` gets a timeout of 3."""
    handler_completed = False

    async def task() -> None:
        nonlocal handler_completed
        await asyncio.sleep(2)
        handler_completed = True

    handler_task, remaining_connections = self.run_app(unused_port_socket, 3, task)

    assert handler_completed is True
    assert handler_task.done()
    assert not handler_task.cancelled()
    assert remaining_connections == 0
def test_shutdown_pending_handler_responds(
    self, unused_port_socket: socket.socket
) -> None:
    """A handler already in progress at shutdown still returns its response."""
    sock = unused_port_socket
    port = sock.getsockname()[1]
    base_url = f"http://127.0.0.1:{port}"
    finished = False
    t = None

    async def test() -> None:
        async def test_resp(sess: ClientSession) -> None:
            async with sess.get(f"{base_url}/") as resp:
                body = await resp.text()
                assert body == "FOO"

        await asyncio.sleep(1)
        async with ClientSession() as sess:
            # Local to test(); distinct from the outer ``t`` set in run_test().
            pending_request = asyncio.create_task(test_resp(sess))
            await asyncio.sleep(1)
            # Handler is in-progress while we trigger server shutdown.
            async with sess.get(f"{base_url}/stop"):
                pass

            assert finished is False
            # Handler should still complete and produce a response.
            await pending_request

    async def run_test(app: web.Application) -> AsyncIterator[None]:
        nonlocal t
        t = asyncio.create_task(test())
        yield
        await t

    async def handler(request: web.Request) -> web.Response:
        nonlocal finished
        await asyncio.sleep(3)
        finished = True
        return web.Response(text="FOO")

    app = web.Application()
    app.cleanup_ctx.append(run_test)
    app.router.add_get("/", handler)
    app.router.add_get("/stop", self.stop)

    web.run_app(app, sock=sock, shutdown_timeout=5)

    assert t is not None
    assert t.exception() is None
    assert finished is True
def test_shutdown_close_websockets(self, unused_port_socket: socket.socket) -> None:
    """Closing websockets in ``on_shutdown`` lets ``run_app`` return promptly.

    Both the server handler and the client loop must finish, and the whole
    run must take under 5 seconds despite ``shutdown_timeout=10``.
    """
    sock = unused_port_socket
    port = sock.getsockname()[1]
    base_url = f"http://127.0.0.1:{port}"
    WS = web.AppKey("ws", set[web.WebSocketResponse])
    client_finished = False
    server_finished = False
    t = None

    async def ws_handler(request: web.Request) -> web.WebSocketResponse:
        nonlocal server_finished
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        # Registered so close_websockets() can find it during shutdown.
        request.app[WS].add(ws)
        async for msg in ws:
            assert False  # No messages actually sent
        server_finished = True
        return ws

    async def close_websockets(app: web.Application) -> None:
        for ws in app[WS]:
            await ws.close(code=WSCloseCode.GOING_AWAY)

    async def test() -> None:
        nonlocal client_finished
        await asyncio.sleep(1)
        async with ClientSession() as sess:
            async with sess.ws_connect(f"{base_url}/ws") as ws:
                # Trigger shutdown while the websocket is still open.
                async with sess.get(f"{base_url}/stop"):
                    pass

                async for msg in ws:
                    assert False  # No messages actually sent
                client_finished = True

    async def run_test(app: web.Application) -> AsyncIterator[None]:
        nonlocal t
        t = asyncio.create_task(test())
        yield
        await asyncio.sleep(0)  # In case test() hasn't resumed yet.
        t.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await t

    app = web.Application()
    app[WS] = set()
    app.on_shutdown.append(close_websockets)
    app.cleanup_ctx.append(run_test)
    app.router.add_get("/ws", ws_handler)
    app.router.add_get("/stop", self.stop)

    start = time.time()
    web.run_app(app, sock=sock, shutdown_timeout=10)
    elapsed = time.time() - start
    assert elapsed < 5
    assert client_finished
    assert server_finished
def chunkify(seq: Sequence[_T], n: int) -> Iterator[Sequence[_T]]:
    """Yield consecutive slices of *seq*, each holding at most *n* items.

    For a positive *n*, every slice but possibly the last has exactly *n*
    items. An empty sequence yields nothing.
    """
    yield from (seq[start : start + n] for start in range(0, len(seq), n))
async def test_create_waiter(self) -> None:
    """_wait() is expected to raise RuntimeError when a waiter is already set."""
    loop = asyncio.get_event_loop()
    stream = self._make_one()
    # Install a real pending future, the same way test_read_nowait_waiter
    # does, rather than a non-Future placeholder.
    stream._waiter = loop.create_future()
    with pytest.raises(RuntimeError):
        await stream._wait("test")
async def test_read(self) -> None:
    """read(30) blocks until data is fed, then returns what was fed."""
    reader = self._make_one()
    event_loop = asyncio.get_event_loop()
    pending_read = event_loop.create_task(reader.read(30))
    # The payload is fed from a loop callback scheduled after the read task.
    event_loop.call_soon(reader.feed_data, self.DATA)

    received = await pending_read
    assert received == self.DATA

    reader.feed_eof()
    leftover = await reader.read()
    assert leftover == b""
async def test_read_until_eof(self) -> None:
    """read(-1) returns everything fed to the stream up to EOF."""
    reader = self._make_one()
    event_loop = asyncio.get_event_loop()
    pending_read = event_loop.create_task(reader.read(-1))

    def feed_everything() -> None:
        # Both pieces and the EOF marker are delivered in one callback.
        for piece in (b"chunk1\n", b"chunk2"):
            reader.feed_data(piece)
        reader.feed_eof()

    event_loop.call_soon(feed_everything)

    received = await pending_read
    assert received == b"chunk1\nchunk2"

    # A further read after EOF yields an empty result.
    leftover = await reader.read()
    assert leftover == b""
async def test_readline_read_byte_count(self) -> None:
    """After readline(), sized and unsized reads continue from the next byte."""
    reader = self._make_one()
    reader.feed_data(self.DATA)

    # Consume the first line; its content is not under test here.
    await reader.readline()

    next_seven = await reader.read(7)
    assert next_seven == b"line2\nl"

    reader.feed_eof()
    remainder = await reader.read()
    assert remainder == b"ine3\n"
@pytest.mark.parametrize("separator", [b"$", b"$$"])
async def test_readuntil_limit(self, separator: bytes) -> None:
    """readuntil() raises LineTooLong when data fed later exceeds the limit."""
    reader = self._make_one(limit=4)

    def feed_after_call() -> None:
        # The stream is fed only once readuntil() is already waiting.
        for piece in (b"chunk1", b"chunk2" + separator, b"chunk3#"):
            reader.feed_data(piece)
        reader.feed_eof()

    asyncio.get_event_loop().call_soon(feed_after_call)

    with pytest.raises(LineTooLong):
        await reader.readuntil(separator)

    # What follows the over-long chunk is still readable.
    remainder = await reader.read()
    assert remainder == b"chunk3#"
@pytest.mark.parametrize("separator", [b"!", b"!!"])
async def test_readuntil_read_byte_count(self, separator: bytes) -> None:
    """After readuntil(), sized and unsized reads continue from the next byte."""
    reader = self._make_one()
    # Three records, each terminated by the separator.
    payload = separator.join((b"line1", b"line2", b"line3", b""))
    reader.feed_data(payload)

    await reader.readuntil(separator)

    # Second record, its separator, and one byte of the third record.
    wanted = 6 + len(separator)
    partial = await reader.read(wanted)
    assert partial == b"line2" + separator + b"l"

    reader.feed_eof()
    remainder = await reader.read()
    assert remainder == b"ine3" + separator
async def test_readexactly_eof(self) -> None:
    """readexactly(n) raises IncompleteReadError when EOF arrives too early."""
    reader = self._make_one()
    expected_len = 2 * len(self.DATA)
    event_loop = asyncio.get_event_loop()
    pending_read = event_loop.create_task(reader.readexactly(expected_len))

    def feed_half_then_eof() -> None:
        # Only half of the requested bytes are delivered before EOF.
        reader.feed_data(self.DATA)
        reader.feed_eof()

    event_loop.call_soon(feed_half_then_eof)

    with pytest.raises(asyncio.IncompleteReadError) as exc_info:
        await pending_read
    error = exc_info.value
    assert error.partial == self.DATA
    assert error.expected == expected_len
    assert str(error) == "18 bytes read on a total of 36 expected bytes"

    leftover = await reader.read()
    assert leftover == b""
async def test_exception_waiter(self) -> None:
    """An exception set while readline() is waiting surfaces in that task."""
    reader = self._make_one()
    event_loop = asyncio.get_event_loop()

    async def inject_error() -> None:
        reader.set_exception(ValueError())

    reading = event_loop.create_task(reader.readline())
    failing = event_loop.create_task(inject_error())

    # Let both tasks run to completion before inspecting the reader task.
    await asyncio.wait((reading, failing))

    with pytest.raises(ValueError):
        reading.result()
async def test_read_nowait_n(self) -> None:
    """read_nowait(n) returns n bytes; read_nowait() then drains the rest."""
    reader = self._make_one()
    reader.feed_data(b"line1\nline2\n")

    first_four = reader.read_nowait(4)
    assert first_four == b"line"

    drained = reader.read_nowait()
    assert drained == b"1\nline2\n"

    # Nothing is buffered any more.
    empty = reader.read_nowait()
    assert empty == b""

    reader.feed_eof()
    after_eof = await reader.read()
    assert after_eof == b""
async def test_readchunk_wait_eof(self) -> None:
    """readchunk() waiting on an empty stream returns (b"", False) at EOF."""
    reader = self._make_one()

    async def feed_eof_later() -> None:
        await asyncio.sleep(0.1)
        reader.feed_eof()

    feeder = asyncio.get_event_loop().create_task(feed_eof_later())

    payload, end_of_chunk = await reader.readchunk()
    assert payload == b""
    assert not end_of_chunk
    assert reader.is_eof()

    # Make sure the helper task has finished before the test returns.
    await feeder
async def test_readchunk_with_unread(self) -> None:
    """Check that unread_data() does not break controlled chunk receiving."""
    # Each unread_data() call below is expected to emit this deprecation.
    deprecation = (
        r"^unread_data\(\) is deprecated and will be "
        r"removed in future releases \(#3260\)$"
    )
    stream = self._make_one()

    # Send 2 chunks
    stream.begin_http_chunk_receiving()
    stream.feed_data(b"part1")
    stream.end_http_chunk_receiving()
    stream.begin_http_chunk_receiving()
    stream.feed_data(b"part2")
    stream.end_http_chunk_receiving()

    # Read only one chunk, and verify it before building on it.
    data, end_of_chunk = await stream.readchunk()
    assert b"part1" == data
    assert end_of_chunk

    # Try to unread a part of the first chunk
    with pytest.deprecated_call(match=deprecation):
        stream.unread_data(b"rt1")

    # The end_of_chunk signal was already received for the first chunk,
    # so we receive up to the second one
    data, end_of_chunk = await stream.readchunk()
    assert b"rt1part2" == data
    assert end_of_chunk

    # Unread a part of the second chunk
    with pytest.deprecated_call(match=deprecation):
        stream.unread_data(b"rt2")

    data, end_of_chunk = await stream.readchunk()
    assert b"rt2" == data
    # end_of_chunk was already received for this chunk
    assert not end_of_chunk

    stream.feed_eof()
    data, end_of_chunk = await stream.readchunk()
    assert b"" == data
    assert not end_of_chunk
async def test_read_empty_chunks(self) -> None:
    """Feeding empty HTTP chunks between data chunks must not break read()."""
    reader = self._make_one()

    # One entry per HTTP chunk, in arrival order; None marks an empty chunk.
    # The leading empty chunk is a significant special case, and empty
    # chunks in the middle are possible when compression is enabled.
    chunk_payloads = (None, b"ungzipped", None, None, b" data")
    for payload in chunk_payloads:
        reader.begin_http_chunk_receiving()
        if payload is not None:
            reader.feed_data(payload)
        reader.end_http_chunk_receiving()

    reader.feed_eof()

    body = await reader.read()
    assert body == b"ungzipped data"
async def test___repr__waiter(self) -> None:
    """repr() shows a pending waiter and omits it once the waiter is cleared."""
    loop = asyncio.get_event_loop()
    stream = self._make_one()
    stream._waiter = loop.create_future()
    # Only the prefix is compared; the remainder of the Future repr may
    # carry extra details.
    assert repr(stream).startswith("<StreamReader w=<Future pending")
    # Resolve and await the future so it is not left pending, then detach it.
    stream._waiter.set_result(None)
    await stream._waiter
    stream._waiter = None
    assert "<StreamReader>" == repr(stream)
class TestDataQueue:
    """Tests for streams.DataQueue, exercised through the ``buffer`` fixture."""

    def test_is_eof(self, buffer: streams.DataQueue[bytes]) -> None:
        # is_eof() is false until feed_eof() has been called.
        assert buffer.is_eof() is False or not buffer.is_eof()
        buffer.feed_eof()
        assert buffer.is_eof()

    def test_at_eof(self, buffer: streams.DataQueue[bytes]) -> None:
        assert not buffer.at_eof()
        buffer.feed_eof()
        assert buffer.at_eof()
        # Buffered data means the queue is no longer "at" EOF.
        buffer._buffer.append(b"foo")
        assert not buffer.at_eof()

    def test_feed_data(self, buffer: streams.DataQueue[bytes]) -> None:
        payload = b" "
        buffer.feed_data(payload)
        assert list(buffer._buffer) == [payload]

    def test_feed_eof(self, buffer: streams.DataQueue[bytes]) -> None:
        buffer.feed_eof()
        assert buffer._eof

    async def test_read(self, buffer: streams.DataQueue[bytes]) -> None:
        payload = b""
        # The item is fed from a loop callback while read() is waiting.
        asyncio.get_event_loop().call_soon(buffer.feed_data, payload)
        received = await buffer.read()
        assert payload is received

    async def test_read_eof(self, buffer: streams.DataQueue[bytes]) -> None:
        asyncio.get_event_loop().call_soon(buffer.feed_eof)

        with pytest.raises(streams.EofStream):
            await buffer.read()

    async def test_read_cancelled(self, buffer: streams.DataQueue[bytes]) -> None:
        pending_read = asyncio.get_event_loop().create_task(buffer.read())
        # Yield once so the read task can register its waiter.
        await asyncio.sleep(0)
        registered_waiter = buffer._waiter
        assert asyncio.isfuture(registered_waiter)

        pending_read.cancel()
        with pytest.raises(asyncio.CancelledError):
            await pending_read
        assert registered_waiter.cancelled()
        assert buffer._waiter is None

        # Feeding afterwards must not bring a waiter back.
        buffer.feed_data(b"test")
        assert buffer._waiter is None

    async def test_read_until_eof(self, buffer: streams.DataQueue[bytes]) -> None:
        payload = b""
        buffer.feed_data(payload)
        buffer.feed_eof()

        received = await buffer.read()
        assert received is payload

        with pytest.raises(streams.EofStream):
            await buffer.read()

    async def test_read_exc(self, buffer: streams.DataQueue[bytes]) -> None:
        payload = b""
        buffer.feed_data(payload)
        buffer.set_exception(ValueError)

        # Buffered data is delivered first; the error surfaces afterwards.
        received = await buffer.read()
        assert payload is received

        with pytest.raises(ValueError):
            await buffer.read()

    async def test_read_exception(self, buffer: streams.DataQueue[bytes]) -> None:
        buffer.set_exception(ValueError())

        with pytest.raises(ValueError):
            await buffer.read()

    async def test_read_exception_with_data(
        self, buffer: streams.DataQueue[bytes]
    ) -> None:
        payload = b""
        buffer.feed_data(payload)
        buffer.set_exception(ValueError())

        received = await buffer.read()
        assert payload is received
        with pytest.raises(ValueError):
            await buffer.read()

    async def test_read_exception_on_wait(
        self, buffer: streams.DataQueue[bytes]
    ) -> None:
        pending_read = asyncio.get_event_loop().create_task(buffer.read())
        # Yield once so the read task can register its waiter.
        await asyncio.sleep(0)
        assert asyncio.isfuture(buffer._waiter)

        buffer.feed_eof()
        buffer.set_exception(ValueError())

        with pytest.raises(ValueError):
            await pending_read

    def test_exception(self, buffer: streams.DataQueue[bytes]) -> None:
        assert buffer.exception() is None

        error = ValueError()
        buffer.set_exception(error)
        assert buffer.exception() is error

    async def test_exception_waiter(self, buffer: streams.DataQueue[bytes]) -> None:
        event_loop = asyncio.get_event_loop()

        async def inject_error() -> None:
            buffer.set_exception(ValueError())

        reading = event_loop.create_task(buffer.read())
        failing = event_loop.create_task(inject_error())

        await asyncio.wait([reading, failing])

        with pytest.raises(ValueError):
            reading.result()
reader.feed_data(b"1") assert reader._waiter is None async def test_feed_eof_waiters(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() reader = streams.StreamReader(protocol, 2**16, loop=loop) waiter = reader._waiter = loop.create_future() eof_waiter = reader._eof_waiter = loop.create_future() reader.feed_eof() assert reader._eof assert waiter.done() assert eof_waiter.done() assert reader._waiter is None assert reader._eof_waiter is None # type: ignore[unreachable] async def test_feed_eof_cancelled(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() reader = streams.StreamReader(protocol, 2**16, loop=loop) waiter = reader._waiter = loop.create_future() eof_waiter = reader._eof_waiter = loop.create_future() waiter.set_result(1) eof_waiter.set_result(1) reader.feed_eof() assert waiter.done() assert eof_waiter.done() assert reader._waiter is None assert reader._eof_waiter is None # type: ignore[unreachable] async def test_on_eof(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() reader = streams.StreamReader(protocol, 2**16, loop=loop) on_eof = mock.Mock() reader.on_eof(on_eof) assert not on_eof.called reader.feed_eof() assert on_eof.called async def test_on_eof_empty_reader() -> None: reader = streams.EmptyStreamReader() on_eof = mock.Mock() reader.on_eof(on_eof) assert on_eof.called async def test_on_eof_exc_in_callback(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() reader = streams.StreamReader(protocol, 2**16, loop=loop) on_eof = mock.Mock() on_eof.side_effect = ValueError reader.on_eof(on_eof) assert not on_eof.called reader.feed_eof() assert on_eof.called assert not reader._eof_callbacks # type: ignore[unreachable] async def test_on_eof_exc_in_callback_empty_stream_reader() -> None: reader = streams.EmptyStreamReader() on_eof = mock.Mock() on_eof.side_effect = ValueError reader.on_eof(on_eof) assert on_eof.called async def test_on_eof_eof_is_set(protocol: BaseProtocol) -> None: loop = 
asyncio.get_event_loop() reader = streams.StreamReader(protocol, 2**16, loop=loop) reader.feed_eof() on_eof = mock.Mock() reader.on_eof(on_eof) assert on_eof.called assert not reader._eof_callbacks async def test_on_eof_eof_is_set_exception(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() reader = streams.StreamReader(protocol, 2**16, loop=loop) reader.feed_eof() on_eof = mock.Mock() on_eof.side_effect = ValueError reader.on_eof(on_eof) assert on_eof.called assert not reader._eof_callbacks async def test_set_exception(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() reader = streams.StreamReader(protocol, 2**16, loop=loop) waiter = reader._waiter = loop.create_future() eof_waiter = reader._eof_waiter = loop.create_future() exc = ValueError() reader.set_exception(exc) assert waiter.exception() is exc assert eof_waiter.exception() is exc assert reader._waiter is None assert reader._eof_waiter is None # type: ignore[unreachable] async def test_set_exception_cancelled(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() reader = streams.StreamReader(protocol, 2**16, loop=loop) waiter = reader._waiter = loop.create_future() eof_waiter = reader._eof_waiter = loop.create_future() waiter.set_result(1) eof_waiter.set_result(1) exc = ValueError() reader.set_exception(exc) assert waiter.exception() is None assert eof_waiter.exception() is None assert reader._waiter is None assert reader._eof_waiter is None # type: ignore[unreachable] async def test_set_exception_eof_callbacks(protocol: BaseProtocol) -> None: loop = asyncio.get_event_loop() reader = streams.StreamReader(protocol, 2**16, loop=loop) on_eof = mock.Mock() reader.on_eof(on_eof) reader.set_exception(ValueError()) assert not on_eof.called assert not reader._eof_callbacks async def test_stream_reader_lines() -> None: line_iter = iter(DATA.splitlines(keepends=True)) async for line in await create_stream(): assert line == next(line_iter, None) pytest.raises(StopIteration, 
next, line_iter) async def test_stream_reader_chunks_complete() -> None: # Tests if chunked iteration works if the chunking works out # (i.e. the data is divisible by the chunk size) chunk_iter = chunkify(DATA, 9) async for data in (await create_stream()).iter_chunked(9): assert data == next(chunk_iter, None) pytest.raises(StopIteration, next, chunk_iter) async def test_stream_reader_chunks_incomplete() -> None: # Tests if chunked iteration works if the last chunk is incomplete chunk_iter = chunkify(DATA, 8) async for data in (await create_stream()).iter_chunked(8): assert data == next(chunk_iter, None) pytest.raises(StopIteration, next, chunk_iter) async def test_data_queue_empty() -> None: # Tests that async looping yields nothing if nothing is there loop = asyncio.get_event_loop() buffer: streams.DataQueue[bytes] = streams.DataQueue(loop) buffer.feed_eof() async for _ in buffer: assert False async def test_data_queue_items() -> None: # Tests that async looping yields objects identically loop = asyncio.get_event_loop() buffer = streams.DataQueue[str](loop) items = ["a", "b"] buffer.feed_data(items[0]) buffer.feed_data(items[1]) buffer.feed_eof() item_iter = iter(items) async for item in buffer: assert item is next(item_iter, None) pytest.raises(StopIteration, next, item_iter) async def test_stream_reader_iter_any() -> None: it = iter([b"line1\nline2\nline3\n"]) async for raw in (await create_stream()).iter_any(): assert raw == next(it) pytest.raises(StopIteration, next, it) async def test_stream_reader_iter() -> None: it = iter([b"line1\n", b"line2\n", b"line3\n"]) async for raw in await create_stream(): assert raw == next(it) pytest.raises(StopIteration, next, it) async def test_stream_reader_iter_chunks_no_chunked_encoding() -> None: it = iter([b"line1\nline2\nline3\n"]) async for data, end_of_chunk in (await create_stream()).iter_chunks(): assert (data, end_of_chunk) == (next(it), False) pytest.raises(StopIteration, next, it) async def 
test_stream_reader_iter_chunks_chunked_encoding( protocol: BaseProtocol, ) -> None: loop = asyncio.get_event_loop() stream = streams.StreamReader(protocol, 2**16, loop=loop) for line in DATA.splitlines(keepends=True): stream.begin_http_chunk_receiving() stream.feed_data(line) stream.end_http_chunk_receiving() stream.feed_eof() it = iter([b"line1\n", b"line2\n", b"line3\n"]) async for data, end_of_chunk in stream.iter_chunks(): assert (data, end_of_chunk) == (next(it), True) pytest.raises(StopIteration, next, it) def test_isinstance_check() -> None: assert isinstance(streams.EMPTY_PAYLOAD, streams.StreamReader) async def test_stream_reader_pause_on_high_water_chunks( protocol: mock.Mock, ) -> None: """Test that reading is paused when chunk count exceeds high water mark.""" loop = asyncio.get_event_loop() # Use small limit so high_water_chunks is small: limit // 4 = 10 stream = streams.StreamReader(protocol, limit=40, loop=loop) assert stream._high_water_chunks == 10 assert stream._low_water_chunks == 5 # Feed chunks until we exceed high_water_chunks for i in range(12): stream.begin_http_chunk_receiving() stream.feed_data(b"x") # 1 byte per chunk stream.end_http_chunk_receiving() # pause_reading should have been called when chunk count exceeded 10 protocol.pause_reading.assert_called() async def test_stream_reader_resume_on_low_water_chunks( protocol: mock.Mock, ) -> None: """Test that reading resumes when chunk count drops below low water mark.""" loop = asyncio.get_event_loop() # Use small limit so high_water_chunks is small: limit // 4 = 10 stream = streams.StreamReader(protocol, limit=40, loop=loop) assert stream._high_water_chunks == 10 assert stream._low_water_chunks == 5 # Feed chunks until we exceed high_water_chunks for i in range(12): stream.begin_http_chunk_receiving() stream.feed_data(b"x") # 1 byte per chunk stream.end_http_chunk_receiving() # Simulate that reading was paused protocol._reading_paused = True protocol.pause_reading.reset_mock() # Read data 
to reduce both size and chunk count # Reading will consume chunks and reduce _http_chunk_splits data = await stream.read(10) assert data == b"xxxxxxxxxx" # resume_reading should have been called when both size and chunk count # dropped below their respective low water marks protocol.resume_reading.assert_called() async def test_stream_reader_no_resume_when_chunks_still_high( protocol: mock.Mock, ) -> None: """Test that reading doesn't resume if chunk count is still above low water.""" loop = asyncio.get_event_loop() # Use small limit so high_water_chunks is small: limit // 4 = 10 stream = streams.StreamReader(protocol, limit=40, loop=loop) # Feed many chunks for i in range(12): stream.begin_http_chunk_receiving() stream.feed_data(b"x") stream.end_http_chunk_receiving() # Simulate that reading was paused protocol._reading_paused = True # Read only a few bytes - chunk count will still be high data = await stream.read(2) assert data == b"xx" # resume_reading should NOT be called because chunk count is still >= low_water_chunks protocol.resume_reading.assert_not_called() async def test_stream_reader_read_non_chunked_response( protocol: mock.Mock, ) -> None: """Test that non-chunked responses work correctly (no chunk tracking).""" loop = asyncio.get_event_loop() stream = streams.StreamReader(protocol, limit=40, loop=loop) # Non-chunked: just feed data without begin/end_http_chunk_receiving stream.feed_data(b"Hello World") # _http_chunk_splits should be None for non-chunked responses assert stream._http_chunk_splits is None # Reading should work without issues data = await stream.read(5) assert data == b"Hello" data = await stream.read(6) assert data == b" World" async def test_stream_reader_resume_non_chunked_when_paused( protocol: mock.Mock, ) -> None: """Test that resume works for non-chunked responses when paused due to size.""" loop = asyncio.get_event_loop() # Small limit so we can trigger pause via size stream = streams.StreamReader(protocol, limit=10, loop=loop) 
# Feed data that exceeds high_water (limit * 2 = 20) stream.feed_data(b"x" * 25) # Simulate that reading was paused due to size protocol._reading_paused = True protocol.pause_reading.assert_called() # Read enough to drop below low_water (limit = 10) data = await stream.read(20) assert data == b"x" * 20 # resume_reading should be called (size is now 5 < low_water 10) protocol.resume_reading.assert_called() @pytest.mark.parametrize("limit", [1, 2, 4]) async def test_stream_reader_small_limit_resumes_reading( protocol: mock.Mock, limit: int, ) -> None: """Test that small limits still allow resume_reading to be called. Even with very small limits, high_water_chunks should be at least 3 and low_water_chunks should be at least 2, with high > low to ensure proper flow control. """ loop = asyncio.get_event_loop() stream = streams.StreamReader(protocol, limit=limit, loop=loop) # Verify minimum thresholds are enforced and high > low assert stream._high_water_chunks >= 3 assert stream._low_water_chunks >= 2 assert stream._high_water_chunks > stream._low_water_chunks # Set up pause/resume side effects def pause_reading() -> None: protocol._reading_paused = True protocol.pause_reading.side_effect = pause_reading def resume_reading() -> None: protocol._reading_paused = False protocol.resume_reading.side_effect = resume_reading # Feed 4 chunks (triggers pause at > high_water_chunks which is >= 3) for char in b"abcd": stream.begin_http_chunk_receiving() stream.feed_data(bytes([char])) stream.end_http_chunk_receiving() # Reading should now be paused assert protocol._reading_paused is True assert protocol.pause_reading.called # Read all data - should resume (chunk count drops below low_water_chunks) data = stream.read_nowait() assert data == b"abcd" assert stream._size == 0 protocol.resume_reading.assert_called() assert protocol._reading_paused is False ================================================ FILE: tests/test_tcp_helpers.py ================================================ 
import socket
from unittest import mock

import pytest

from aiohttp.tcp_helpers import tcp_nodelay

has_ipv6: bool = socket.has_ipv6
if has_ipv6:
    # The socket.has_ipv6 flag may be True if Python was built with IPv6
    # support, but the target system still may not have it.
    # So let's ensure that we really have IPv6 support.
    try:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM):
            pass
    except OSError:  # pragma: no cover
        has_ipv6 = False


# nodelay


def test_tcp_nodelay_exception() -> None:
    """tcp_nodelay() must not propagate an OSError raised by setsockopt()."""
    transport = mock.Mock()
    s = mock.Mock()
    s.setsockopt = mock.Mock()
    s.family = socket.AF_INET
    # Every setsockopt() call on the fake socket fails.
    s.setsockopt.side_effect = OSError
    transport.get_extra_info.return_value = s
    tcp_nodelay(transport, True)
    # The option was still attempted exactly as requested.
    s.setsockopt.assert_called_with(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)


def test_tcp_nodelay_enable() -> None:
    """Enabling nodelay sets TCP_NODELAY on a real IPv4 stream socket."""
    transport = mock.Mock()
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        transport.get_extra_info.return_value = s
        tcp_nodelay(transport, True)
        assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)


def test_tcp_nodelay_enable_and_disable() -> None:
    """TCP_NODELAY can be switched on and then off again on the same socket."""
    transport = mock.Mock()
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        transport.get_extra_info.return_value = s
        tcp_nodelay(transport, True)
        assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
        tcp_nodelay(transport, False)
        assert not s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)


@pytest.mark.skipif(not has_ipv6, reason="IPv6 is not available")
def test_tcp_nodelay_enable_ipv6() -> None:
    """Enabling nodelay sets TCP_NODELAY on a real IPv6 stream socket."""
    transport = mock.Mock()
    with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
        transport.get_extra_info.return_value = s
        tcp_nodelay(transport, True)
        assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)


@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires unix sockets")
def test_tcp_nodelay_enable_unix() -> None:
    """tcp_nodelay() must leave an AF_UNIX socket untouched."""
    # do not set nodelay for unix socket
    transport = mock.Mock()
    s = mock.Mock(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    transport.get_extra_info.return_value = s
    tcp_nodelay(transport, True)
    assert not s.setsockopt.called


def test_tcp_nodelay_enable_no_socket() -> None:
    """tcp_nodelay() is a no-op when the transport exposes no socket."""
    transport = mock.Mock()
    transport.get_extra_info.return_value = None
    # Passing means only that the call returns without raising.
    tcp_nodelay(transport, True)


================================================
FILE: tests/test_test_utils.py
================================================
import asyncio
import gzip
import socket
import sys
from collections.abc import Iterator, Mapping
from typing import NoReturn
from unittest import mock

import pytest
from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL

import aiohttp
from aiohttp import web
from aiohttp.pytest_plugin import AiohttpClient
from aiohttp.test_utils import (
    AioHTTPTestCase,
    RawTestServer,
    TestClient,
    TestServer,
    get_port_socket,
    loop_context,
    make_mocked_request,
)

if sys.version_info >= (3, 11):
    from typing import assert_type

# Shorthand for the client type produced by the fixtures in this module.
_TestClient = TestClient[web.Request, web.Application]

# Canonical response payload of the example app, in the forms the tests compare.
_hello_world_str = "Hello, world"
_hello_world_bytes = _hello_world_str.encode("utf-8")
_hello_world_gz = gzip.compress(_hello_world_bytes)


def _create_example_app() -> web.Application:
    """Build the small application shared by the tests in this module.

    Routes (all registered for every HTTP method via ``"*"``):
      * ``/``          -> plain "Hello, world" body
      * ``/websocket`` -> echoes one text message back with "/answer" appended
      * ``/cookie``    -> "Hello, world" body plus a ``cookie=val`` cookie
    """

    async def hello(request: web.Request) -> web.Response:
        """Return the canonical hello-world body."""
        return web.Response(body=_hello_world_bytes)

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        """Answer a single TEXT frame, then finish the handler."""
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        msg = await ws.receive()
        assert msg.type == aiohttp.WSMsgType.TEXT
        await ws.send_str(msg.data + "/answer")
        return ws

    async def cookie_handler(request: web.Request) -> web.Response:
        """Return the hello-world body and set a cookie on the response."""
        resp = web.Response(body=_hello_world_bytes)
        resp.set_cookie("cookie", "val")
        return resp

    app = web.Application()
    app.router.add_route("*", "/", hello)
    app.router.add_route("*", "/websocket", websocket_handler)
    app.router.add_route("*", "/cookie", cookie_handler)
    return app


# these exist to test the pytest scenario
@pytest.fixture
def loop() -> Iterator[asyncio.AbstractEventLoop]:
    """Yield an event loop obtained from ``loop_context()`` for one test."""
    with loop_context() as loop:
        yield loop
@pytest.fixture def app() -> web.Application: return _create_example_app() @pytest.fixture def test_client( loop: asyncio.AbstractEventLoop, app: web.Application ) -> Iterator[_TestClient]: async def make_client() -> TestClient[web.Request, web.Application]: return TestClient(TestServer(app)) client = loop.run_until_complete(make_client()) loop.run_until_complete(client.start_server()) yield client loop.run_until_complete(client.close()) async def test_aiohttp_client_close_is_idempotent() -> None: # a test client, called multiple times, should # not attempt to close the server again. app = _create_example_app() client = TestClient(TestServer(app)) await client.close() await client.close() class TestAioHTTPTestCase(AioHTTPTestCase): async def get_application(self) -> web.Application: return _create_example_app() async def test_example_with_loop(self) -> None: request = await self.client.request("GET", "/") assert request.status == 200 text = await request.text() assert _hello_world_str == text async def test_example_without_explicit_loop(self) -> None: request = await self.client.request("GET", "/") assert request.status == 200 text = await request.text() assert _hello_world_str == text async def test_inner_example(self) -> None: async def test_get_route() -> None: resp = await self.client.request("GET", "/") assert resp.status == 200 text = await resp.text() assert _hello_world_str == text await test_get_route() def test_get_route(loop: asyncio.AbstractEventLoop, test_client: _TestClient) -> None: async def test_get_route() -> None: resp = await test_client.request("GET", "/") assert resp.status == 200 text = await resp.text() assert _hello_world_str == text loop.run_until_complete(test_get_route()) async def test_client_websocket( loop: asyncio.AbstractEventLoop, test_client: _TestClient ) -> None: resp = await test_client.ws_connect("/websocket") await resp.send_str("foo") msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.TEXT assert "foo" in 
msg.data await resp.send_str("close") msg = await resp.receive() assert msg.type == aiohttp.WSMsgType.CLOSE async def test_client_cookie( loop: asyncio.AbstractEventLoop, test_client: _TestClient ) -> None: assert not test_client.session.cookie_jar await test_client.get("/cookie") cookies = list(test_client.session.cookie_jar) assert cookies[0].key == "cookie" assert cookies[0].value == "val" @pytest.mark.parametrize( "method", ["get", "post", "options", "post", "put", "patch", "delete"] ) async def test_test_client_methods( method: str, loop: asyncio.AbstractEventLoop, test_client: _TestClient ) -> None: resp = await getattr(test_client, method)("/") assert resp.status == 200 text = await resp.text() assert _hello_world_str == text async def test_test_client_head( loop: asyncio.AbstractEventLoop, test_client: _TestClient ) -> None: resp = await test_client.head("/") assert resp.status == 200 @pytest.mark.parametrize("headers", [{"token": "x"}, CIMultiDict({"token": "x"}), {}]) def test_make_mocked_request(headers: Mapping[str, str]) -> None: req = make_mocked_request("GET", "/", headers=headers) assert req.method == "GET" assert req.path == "/" assert isinstance(req, web.Request) assert isinstance(req.headers, CIMultiDictProxy) def test_make_mocked_request_sslcontext() -> None: req = make_mocked_request("GET", "/") assert req.transport is not None assert req.transport.get_extra_info("sslcontext") is None def test_make_mocked_request_unknown_extra_info() -> None: req = make_mocked_request("GET", "/") assert req.transport is not None assert req.transport.get_extra_info("unknown_extra_info") is None def test_make_mocked_request_app() -> None: app = mock.Mock() req = make_mocked_request("GET", "/", app=app) assert req.app is app def test_make_mocked_request_app_can_store_values() -> None: req = make_mocked_request("GET", "/") req.app["a_field"] = "a_value" assert req.app["a_field"] == "a_value" def test_make_mocked_request_app_access_non_existing() -> None: req = 
make_mocked_request("GET", "/") with pytest.raises(AttributeError): req.app.foo # type: ignore[attr-defined] def test_make_mocked_request_match_info() -> None: req = make_mocked_request("GET", "/", match_info={"a": "1", "b": "2"}) assert req.match_info == {"a": "1", "b": "2"} def test_make_mocked_request_content() -> None: payload = mock.Mock() req = make_mocked_request("GET", "/", payload=payload) assert req.content is payload async def test_make_mocked_request_empty_payload() -> None: req = make_mocked_request("GET", "/") assert await req.read() == b"" def test_make_mocked_request_transport() -> None: transport = mock.Mock() req = make_mocked_request("GET", "/", transport=transport) assert req.transport is transport async def test_test_client_props() -> None: app = _create_example_app() server = TestServer(app, scheme="http", host="127.0.0.1") client = TestClient(server) assert client.scheme == "http" assert client.host == "127.0.0.1" assert client.port == 0 async with client: assert isinstance(client.port, int) assert client.server is not None if sys.version_info >= (3, 11): assert_type(client.app, web.Application) assert client.app is not None assert client.port == 0 async def test_test_client_raw_server_props() -> None: async def hello(request: web.BaseRequest) -> NoReturn: assert False server = RawTestServer(hello, scheme="http", host="127.0.0.1") client = TestClient(server) assert client.scheme == "http" assert client.host == "127.0.0.1" assert client.port == 0 async with client: assert isinstance(client.port, int) assert client.server is not None if sys.version_info >= (3, 11): assert_type(client.app, None) assert client.app is None assert client.port == 0 async def test_test_server_context_manager(loop: asyncio.AbstractEventLoop) -> None: app = _create_example_app() async with TestServer(app) as server: client = aiohttp.ClientSession() resp = await client.head(server.make_url("/")) assert resp.status == 200 resp.close() await client.close() def 
test_client_unsupported_arg() -> None: with pytest.raises(TypeError) as e: TestClient("string") # type: ignore[call-overload] assert ( str(e.value) == "server must be TestServer instance, found type: " ) async def test_server_make_url_yarl_compatibility( loop: asyncio.AbstractEventLoop, ) -> None: app = _create_example_app() async with TestServer(app) as server: make_url = server.make_url assert make_url(URL("/foo")) == make_url("/foo") with pytest.raises(AssertionError): make_url("http://foo.com") with pytest.raises(AssertionError): make_url(URL("http://foo.com")) @pytest.mark.xfail(reason="https://github.com/pytest-dev/pytest/issues/13546") def test_testcase_no_app( testdir: pytest.Testdir, loop: asyncio.AbstractEventLoop ) -> None: testdir.makepyfile(""" from aiohttp.test_utils import AioHTTPTestCase class InvalidTestCase(AioHTTPTestCase): def test_noop(self) -> None: pass """) result = testdir.runpytest() result.stdout.fnmatch_lines(["*TypeError*"]) async def test_disable_retry_persistent_connection( aiohttp_client: AiohttpClient, ) -> None: num_requests = 0 async def handler(request: web.Request) -> web.Response: nonlocal num_requests num_requests += 1 request.protocol.force_close() return web.Response() app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) with pytest.raises(aiohttp.ServerDisconnectedError): await client.get("/") assert num_requests == 1 async def test_server_context_manager( app: web.Application, loop: asyncio.AbstractEventLoop ) -> None: async with TestServer(app) as server: async with aiohttp.ClientSession() as client: async with client.head(server.make_url("/")) as resp: assert resp.status == 200 @pytest.mark.parametrize( "method", ["head", "get", "post", "options", "post", "put", "patch", "delete"] ) async def test_client_context_manager_response( method: str, app: web.Application, loop: asyncio.AbstractEventLoop ) -> None: async with TestClient(TestServer(app)) as client: async with getattr(client, 
method)("/") as resp: assert resp.status == 200 if method != "head": text = await resp.text() assert "Hello, world" in text async def test_custom_port( loop: asyncio.AbstractEventLoop, app: web.Application, unused_port_socket: socket.socket, ) -> None: sock = unused_port_socket port = sock.getsockname()[1] client = TestClient( TestServer(app, port=port, socket_factory=lambda *args, **kwargs: sock) ) await client.start_server() assert client.server.port == port resp = await client.get("/") assert resp.status == 200 text = await resp.text() assert _hello_world_str == text await client.close() @pytest.mark.parametrize( ("hostname", "expected_host"), [("127.0.0.1", "127.0.0.1"), ("localhost", "127.0.0.1"), ("::1", "::1")], ) async def test_test_server_hostnames( hostname: str, expected_host: str, loop: asyncio.AbstractEventLoop ) -> None: app = _create_example_app() server = TestServer(app, host=hostname, loop=loop) async with server: pass assert server.host == expected_host @pytest.mark.parametrize("test_server_cls", [TestServer, RawTestServer]) async def test_base_test_server_socket_factory( test_server_cls: type, app: web.Application, loop: asyncio.AbstractEventLoop ) -> None: factory_called = False def factory(host: str, port: int, family: socket.AddressFamily) -> socket.socket: nonlocal factory_called factory_called = True return get_port_socket(host, port, family) server = test_server_cls(app, loop=loop, socket_factory=factory) async with server: pass assert factory_called ================================================ FILE: tests/test_tracing.py ================================================ import sys from types import SimpleNamespace from typing import Any from unittest import mock from unittest.mock import Mock import pytest from aiosignal import Signal from aiohttp import ClientSession from aiohttp.tracing import ( Trace, TraceConfig, TraceConnectionCreateEndParams, TraceConnectionCreateStartParams, TraceConnectionQueuedEndParams, 
TraceConnectionQueuedStartParams, TraceConnectionReuseconnParams, TraceDnsCacheHitParams, TraceDnsCacheMissParams, TraceDnsResolveHostEndParams, TraceDnsResolveHostStartParams, TraceRequestChunkSentParams, TraceRequestEndParams, TraceRequestExceptionParams, TraceRequestRedirectParams, TraceRequestStartParams, TraceResponseChunkReceivedParams, ) if sys.version_info >= (3, 11): from typing import assert_type class TestTraceConfig: def test_trace_config_ctx_default(self) -> None: trace_config = TraceConfig() assert isinstance(trace_config.trace_config_ctx(), SimpleNamespace) if sys.version_info >= (3, 11): assert_type( trace_config.on_request_chunk_sent, Signal[ClientSession, SimpleNamespace, TraceRequestChunkSentParams], ) def test_trace_config_ctx_factory(self) -> None: trace_config = TraceConfig(trace_config_ctx_factory=dict) assert isinstance(trace_config.trace_config_ctx(), dict) if sys.version_info >= (3, 11): assert_type( trace_config.on_request_start, Signal[ClientSession, dict[str, Any], TraceRequestStartParams], ) def test_trace_config_ctx_request_ctx(self) -> None: trace_request_ctx = Mock() trace_config = TraceConfig() trace_config_ctx = trace_config.trace_config_ctx( trace_request_ctx=trace_request_ctx ) assert trace_config_ctx.trace_request_ctx is trace_request_ctx def test_trace_config_ctx_custom_class(self) -> None: """Custom class instances should be accepted as trace_request_ctx (#10753).""" class MyContext: def __init__(self, request_id: int) -> None: self.request_id = request_id ctx = MyContext(request_id=42) trace_config = TraceConfig() trace_config_ctx = trace_config.trace_config_ctx(trace_request_ctx=ctx) assert trace_config_ctx.trace_request_ctx is ctx assert trace_config_ctx.trace_request_ctx.request_id == 42 def test_freeze(self) -> None: trace_config = TraceConfig() trace_config.freeze() assert trace_config.on_request_start.frozen assert trace_config.on_request_chunk_sent.frozen assert trace_config.on_response_chunk_received.frozen assert 
trace_config.on_request_end.frozen assert trace_config.on_request_exception.frozen assert trace_config.on_request_redirect.frozen assert trace_config.on_connection_queued_start.frozen assert trace_config.on_connection_queued_end.frozen assert trace_config.on_connection_create_start.frozen assert trace_config.on_connection_create_end.frozen assert trace_config.on_connection_reuseconn.frozen assert trace_config.on_dns_resolvehost_start.frozen assert trace_config.on_dns_resolvehost_end.frozen assert trace_config.on_dns_cache_hit.frozen assert trace_config.on_dns_cache_miss.frozen assert trace_config.on_request_headers_sent.frozen class TestTrace: @pytest.mark.parametrize( "signal,params,param_obj", [ ("request_start", (Mock(), Mock(), Mock()), TraceRequestStartParams), ( "request_chunk_sent", (Mock(), Mock(), Mock()), TraceRequestChunkSentParams, ), ( "response_chunk_received", (Mock(), Mock(), Mock()), TraceResponseChunkReceivedParams, ), ("request_end", (Mock(), Mock(), Mock(), Mock()), TraceRequestEndParams), ( "request_exception", (Mock(), Mock(), Mock(), Mock()), TraceRequestExceptionParams, ), ( "request_redirect", (Mock(), Mock(), Mock(), Mock()), TraceRequestRedirectParams, ), ("connection_queued_start", (), TraceConnectionQueuedStartParams), ("connection_queued_end", (), TraceConnectionQueuedEndParams), ("connection_create_start", (), TraceConnectionCreateStartParams), ("connection_create_end", (), TraceConnectionCreateEndParams), ("connection_reuseconn", (), TraceConnectionReuseconnParams), ("dns_resolvehost_start", (Mock(),), TraceDnsResolveHostStartParams), ("dns_resolvehost_end", (Mock(),), TraceDnsResolveHostEndParams), ("dns_cache_hit", (Mock(),), TraceDnsCacheHitParams), ("dns_cache_miss", (Mock(),), TraceDnsCacheMissParams), ], ) async def test_send( # type: ignore[misc] self, signal: str, params: tuple[Mock, ...], param_obj: Any ) -> None: session = Mock() trace_request_ctx = Mock() callback = mock.AsyncMock() trace_config = TraceConfig() 
getattr(trace_config, "on_%s" % signal).append(callback) trace_config.freeze() trace = Trace( session, trace_config, trace_config.trace_config_ctx(trace_request_ctx=trace_request_ctx), ) await getattr(trace, "send_%s" % signal)(*params) callback.assert_called_once_with( session, SimpleNamespace(trace_request_ctx=trace_request_ctx), param_obj(*params), ) ================================================ FILE: tests/test_urldispatch.py ================================================ import asyncio import pathlib import platform import re from collections.abc import ( Awaitable, Callable, Container, Iterable, Mapping, MutableMapping, Sized, ) from functools import partial from typing import NoReturn from urllib.parse import quote, unquote import pytest from yarl import URL import aiohttp from aiohttp import hdrs, web from aiohttp.test_utils import make_mocked_request from aiohttp.web_urldispatcher import ( PATH_SEP, Domain, MaskDomain, SystemRoute, _default_expect_handler, ) def make_handler() -> Callable[[web.Request], Awaitable[NoReturn]]: async def handler(request: web.Request) -> NoReturn: assert False return handler def make_partial_handler() -> Callable[[web.Request], Awaitable[NoReturn]]: async def handler(a: int, request: web.Request) -> NoReturn: assert False return partial(handler, 5) @pytest.fixture def app() -> web.Application: return web.Application() @pytest.fixture def router(app: web.Application) -> web.UrlDispatcher: return app.router @pytest.fixture def fill_routes(router: web.UrlDispatcher) -> Callable[[], list[web.AbstractRoute]]: def go() -> list[web.AbstractRoute]: route1 = router.add_route("GET", "/plain", make_handler()) route2 = router.add_route("GET", "/variable/{name}", make_handler()) resource = router.add_static("/static", pathlib.Path(aiohttp.__file__).parent) return [route1, route2] + list(resource) return go def test_register_uncommon_http_methods(router: web.UrlDispatcher) -> None: uncommon_http_methods = { "PROPFIND", "PROPPATCH", 
"COPY", "LOCK", "UNLOCK", "MOVE", "SUBSCRIBE", "UNSUBSCRIBE", "NOTIFY", } for method in uncommon_http_methods: router.add_route(method, "/handler/to/path", make_handler()) async def test_add_partial_handler(router: web.UrlDispatcher) -> None: handler = make_partial_handler() router.add_get("/handler/to/path", handler) async def test_add_sync_handler(router: web.UrlDispatcher) -> None: def handler(request: web.Request) -> NoReturn: assert False with pytest.raises(TypeError): router.add_get("/handler/to/path", handler) async def test_add_route_root(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/", handler) req = make_mocked_request("GET", "/") info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None async def test_add_route_simple(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/handler/to/path", handler) req = make_mocked_request("GET", "/handler/to/path") info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None async def test_add_with_matchdict(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/handler/{to}", handler) req = make_mocked_request("GET", "/handler/tail") info = await router.resolve(req) assert info is not None assert {"to": "tail"} == info assert handler is info.handler assert info.route.name is None async def test_add_with_matchdict_with_colon(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/handler/{to}", handler) req = make_mocked_request("GET", "/handler/1:2:3") info = await router.resolve(req) assert info is not None assert {"to": "1:2:3"} == info assert handler is info.handler assert info.route.name is None async def test_add_route_with_add_get_shortcut(router: web.UrlDispatcher) -> None: handler = make_handler() 
router.add_get("/handler/to/path", handler) req = make_mocked_request("GET", "/handler/to/path") info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None async def test_add_route_with_add_post_shortcut(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_post("/handler/to/path", handler) req = make_mocked_request("POST", "/handler/to/path") info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None async def test_add_route_with_add_put_shortcut(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_put("/handler/to/path", handler) req = make_mocked_request("PUT", "/handler/to/path") info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None async def test_add_route_with_add_patch_shortcut(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_patch("/handler/to/path", handler) req = make_mocked_request("PATCH", "/handler/to/path") info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None async def test_add_route_with_add_delete_shortcut(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_delete("/handler/to/path", handler) req = make_mocked_request("DELETE", "/handler/to/path") info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None async def test_add_route_with_add_head_shortcut(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_head("/handler/to/path", handler) req = make_mocked_request("HEAD", "/handler/to/path") info = await router.resolve(req) assert info is not None assert 0 == len(info) assert handler is info.handler assert info.route.name is None async def 
test_add_with_name(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/handler/to/path", handler, name="name") req = make_mocked_request("GET", "/handler/to/path") info = await router.resolve(req) assert info is not None assert "name" == info.route.name async def test_add_with_tailing_slash(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/handler/to/path/", handler) req = make_mocked_request("GET", "/handler/to/path/") info = await router.resolve(req) assert info is not None assert {} == info assert handler is info.handler def test_add_invalid_path(router: web.UrlDispatcher) -> None: handler = make_handler() with pytest.raises(ValueError): router.add_route("GET", "/{/", handler) def test_add_url_invalid1(router: web.UrlDispatcher) -> None: handler = make_handler() with pytest.raises(ValueError): router.add_route("post", "/post/{id", handler) def test_add_url_invalid2(router: web.UrlDispatcher) -> None: handler = make_handler() with pytest.raises(ValueError): router.add_route("post", "/post/{id{}}", handler) def test_add_url_invalid3(router: web.UrlDispatcher) -> None: handler = make_handler() with pytest.raises(ValueError): router.add_route("post", "/post/{id{}", handler) def test_add_url_invalid4(router: web.UrlDispatcher) -> None: handler = make_handler() with pytest.raises(ValueError): router.add_route("post", '/post/{id"}', handler) async def test_add_url_escaping(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/+$", handler) req = make_mocked_request("GET", "/+$") info = await router.resolve(req) assert info is not None assert handler is info.handler async def test_any_method(router: web.UrlDispatcher) -> None: handler = make_handler() route = router.add_route(hdrs.METH_ANY, "/", handler) req = make_mocked_request("GET", "/") info1 = await router.resolve(req) assert info1 is not None assert route is info1.route req = make_mocked_request("POST", "/") 
info2 = await router.resolve(req) assert info2 is not None assert info1.route is info2.route async def test_any_method_appears_in_routes(router: web.UrlDispatcher) -> None: handler = make_handler() route = router.add_route(hdrs.METH_ANY, "/", handler) assert route in router.routes() async def test_match_second_result_in_table(router: web.UrlDispatcher) -> None: handler1 = make_handler() handler2 = make_handler() router.add_route("GET", "/h1", handler1) router.add_route("POST", "/h2", handler2) req = make_mocked_request("POST", "/h2") info = await router.resolve(req) assert info is not None assert {} == info assert handler2 is info.handler async def test_raise_method_not_allowed(router: web.UrlDispatcher) -> None: handler1 = make_handler() handler2 = make_handler() router.add_route("GET", "/", handler1) router.add_route("POST", "/", handler2) req = make_mocked_request("PUT", "/") match_info = await router.resolve(req) assert isinstance(match_info.route, SystemRoute) assert {} == match_info with pytest.raises(web.HTTPMethodNotAllowed) as ctx: await match_info.handler(req) exc = ctx.value assert "PUT" == exc.method assert 405 == exc.status assert {"POST", "GET"} == exc.allowed_methods async def test_raise_method_not_found(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/a", handler) req = make_mocked_request("GET", "/b") match_info = await router.resolve(req) assert isinstance(match_info.route, SystemRoute) assert {} == match_info with pytest.raises(web.HTTPNotFound) as ctx: await match_info.handler(req) exc = ctx.value assert 404 == exc.status def test_double_add_url_with_the_same_name(router: web.UrlDispatcher) -> None: handler1 = make_handler() handler2 = make_handler() router.add_route("GET", "/get", handler1, name="name") with pytest.raises(ValueError) as ctx: router.add_route("GET", "/get_other", handler2, name="name") assert str(ctx.value).startswith("Duplicate 'name', already handled by") def test_route_plain(router: 
web.UrlDispatcher) -> None: handler = make_handler() route = router.add_route("GET", "/get", handler, name="name") route2 = next(iter(router["name"])) url = route2.url_for() assert "/get" == str(url) assert route is route2 def test_route_unknown_route_name(router: web.UrlDispatcher) -> None: with pytest.raises(KeyError): router["unknown"] def test_route_dynamic(router: web.UrlDispatcher) -> None: handler = make_handler() route = router.add_route("GET", "/get/{name}", handler, name="name") route2 = next(iter(router["name"])) url = route2.url_for(name="John") assert "/get/John" == str(url) assert route is route2 def test_add_static_path_checks( router: web.UrlDispatcher, tmp_path: pathlib.Path ) -> None: """Test that static paths must exist and be directories.""" with pytest.raises(ValueError, match="does not exist"): router.add_static("/", tmp_path / "does-not-exist") with pytest.raises(ValueError, match="is not a directory"): router.add_static("/", __file__) def test_add_static_path_resolution(router: web.UrlDispatcher) -> None: """Test that static paths are expanded and absolute.""" res = router.add_static("/", "~/..") directory = str(res.get_info()["directory"]) assert directory == str(pathlib.Path.home().resolve(strict=True).parent) def test_add_static(router: web.UrlDispatcher) -> None: resource = router.add_static( "/st", pathlib.Path(aiohttp.__file__).parent, name="static" ) assert router["static"] is resource url = resource.url_for(filename="/dir/a.txt") assert "/st/dir/a.txt" == str(url) assert len(resource) == 2 def test_add_static_append_version(router: web.UrlDispatcher) -> None: resource = router.add_static("/st", pathlib.Path(__file__).parent, name="static") url = resource.url_for(filename="/data.unknown_mime_type", append_version=True) expect_url = ( "/st/data.unknown_mime_type?v=aUsn8CHEhhszc81d28QmlcBW0KQpfS2F4trgQKhOYd8%3D" ) assert expect_url == str(url) def test_add_static_append_version_set_from_constructor( router: web.UrlDispatcher, ) -> None: 
resource = router.add_static( "/st", pathlib.Path(__file__).parent, append_version=True, name="static" ) url = resource.url_for(filename="/data.unknown_mime_type") expect_url = ( "/st/data.unknown_mime_type?v=aUsn8CHEhhszc81d28QmlcBW0KQpfS2F4trgQKhOYd8%3D" ) assert expect_url == str(url) def test_add_static_append_version_override_constructor( router: web.UrlDispatcher, ) -> None: resource = router.add_static( "/st", pathlib.Path(__file__).parent, append_version=True, name="static" ) url = resource.url_for(filename="/data.unknown_mime_type", append_version=False) expect_url = "/st/data.unknown_mime_type" assert expect_url == str(url) def test_add_static_append_version_filename_without_slash( router: web.UrlDispatcher, ) -> None: resource = router.add_static("/st", pathlib.Path(__file__).parent, name="static") url = resource.url_for(filename="data.unknown_mime_type", append_version=True) expect_url = ( "/st/data.unknown_mime_type?v=aUsn8CHEhhszc81d28QmlcBW0KQpfS2F4trgQKhOYd8%3D" ) assert expect_url == str(url) def test_add_static_append_version_non_exists_file(router: web.UrlDispatcher) -> None: resource = router.add_static("/st", pathlib.Path(__file__).parent, name="static") url = resource.url_for(filename="/non_exists_file", append_version=True) assert "/st/non_exists_file" == str(url) def test_add_static_append_version_non_exists_file_without_slash( router: web.UrlDispatcher, ) -> None: resource = router.add_static("/st", pathlib.Path(__file__).parent, name="static") url = resource.url_for(filename="non_exists_file", append_version=True) assert "/st/non_exists_file" == str(url) def test_add_static_append_version_follow_symlink( router: web.UrlDispatcher, tmp_path: pathlib.Path ) -> None: # Tests the access to a symlink, in static folder with apeend_version symlink_path = tmp_path / "append_version_symlink" symlink_target_path = pathlib.Path(__file__).parent pathlib.Path(str(symlink_path)).symlink_to(str(symlink_target_path), True) # Register global static route: 
resource = router.add_static( "/st", str(tmp_path), follow_symlinks=True, append_version=True ) url = resource.url_for(filename="/append_version_symlink/data.unknown_mime_type") expect_url = ( "/st/append_version_symlink/data.unknown_mime_type?" "v=aUsn8CHEhhszc81d28QmlcBW0KQpfS2F4trgQKhOYd8%3D" ) assert expect_url == str(url) def test_add_static_append_version_not_follow_symlink( router: web.UrlDispatcher, tmp_path: pathlib.Path ) -> None: # Tests the access to a symlink, in static folder with apeend_version symlink_path = tmp_path / "append_version_symlink" symlink_target_path = pathlib.Path(__file__).parent pathlib.Path(str(symlink_path)).symlink_to(str(symlink_target_path), True) # Register global static route: resource = router.add_static( "/st", str(tmp_path), follow_symlinks=False, append_version=True ) filename = "/append_version_symlink/data.unknown_mime_type" url = resource.url_for(filename=filename) assert "/st/append_version_symlink/data.unknown_mime_type" == str(url) def test_add_static_quoting(router: web.UrlDispatcher) -> None: resource = router.add_static( "/пре %2Fфикс", pathlib.Path(aiohttp.__file__).parent, name="static" ) assert router["static"] is resource url = resource.url_for(filename="/1 2/файл%2F.txt") assert url.path == "/пре /фикс/1 2/файл%2F.txt" assert str(url) == ( "/%D0%BF%D1%80%D0%B5%20%2F%D1%84%D0%B8%D0%BA%D1%81" "/1%202/%D1%84%D0%B0%D0%B9%D0%BB%252F.txt" ) assert len(resource) == 2 def test_plain_not_match(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/get/path", handler, name="name") route = router["name"] assert isinstance(route, web.Resource) assert route._match("/another/path") is None def test_dynamic_not_match(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/get/{name}", handler, name="name") route = router["name"] assert isinstance(route, web.Resource) assert route._match("/another/path") is None async def test_static_not_match(router: 
web.UrlDispatcher) -> None: router.add_static("/pre", pathlib.Path(aiohttp.__file__).parent, name="name") resource = router["name"] ret = await resource.resolve(make_mocked_request("GET", "/another/path")) assert (None, set()) == ret async def test_add_static_access_resources(router: web.UrlDispatcher) -> None: """Test accessing resource._routes externally. aiohttp-cors accesses the resource._routes, this test ensures that this continues to work. """ # https://github.com/aio-libs/aiohttp-cors/blob/38c6c17bffc805e46baccd7be1b4fd8c69d95dc3/aiohttp_cors/urldispatcher_router_adapter.py#L187 resource = router.add_static( "/st", pathlib.Path(aiohttp.__file__).parent, name="static" ) resource._routes[hdrs.METH_OPTIONS] = resource._routes[hdrs.METH_GET] resource._allowed_methods.add(hdrs.METH_OPTIONS) mapping, allowed_methods = await resource.resolve( make_mocked_request("OPTIONS", "/st/path") ) assert mapping is not None assert allowed_methods == {hdrs.METH_GET, hdrs.METH_OPTIONS, hdrs.METH_HEAD} async def test_add_static_set_options_route(router: web.UrlDispatcher) -> None: """Ensure set_options_route works as expected.""" resource = router.add_static( "/st", pathlib.Path(aiohttp.__file__).parent, name="static" ) async def handler(request: web.Request) -> NoReturn: assert False resource.set_options_route(handler) mapping, allowed_methods = await resource.resolve( make_mocked_request("OPTIONS", "/st/path") ) assert mapping is not None assert allowed_methods == {hdrs.METH_GET, hdrs.METH_OPTIONS, hdrs.METH_HEAD} def test_dynamic_with_trailing_slash(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/get/{name}/", handler, name="name") route = router["name"] assert isinstance(route, web.Resource) assert {"name": "John"} == route._match("/get/John/") def test_len(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/get1", handler, name="name1") router.add_route("GET", "/get2", handler, name="name2") assert 2 
== len(router) def test_iter(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/get1", handler, name="name1") router.add_route("GET", "/get2", handler, name="name2") assert {"name1", "name2"} == set(iter(router)) def test_contains(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/get1", handler, name="name1") router.add_route("GET", "/get2", handler, name="name2") assert "name1" in router assert "name3" not in router def test_static_repr(router: web.UrlDispatcher) -> None: router.add_static("/get", pathlib.Path(aiohttp.__file__).parent, name="name") assert repr(router["name"]).startswith(" None: route = router.add_static("/prefix", pathlib.Path(aiohttp.__file__).parent) assert "/prefix" == route._prefix def test_static_remove_trailing_slash(router: web.UrlDispatcher) -> None: route = router.add_static("/prefix/", pathlib.Path(aiohttp.__file__).parent) assert "/prefix" == route._prefix @pytest.mark.parametrize( "pattern,url,expected", ( (r"{to:\d+}", r"1234", {"to": "1234"}), ("{name}.html", "test.html", {"name": "test"}), (r"{fn:\w+ \d+}", "abc 123", {"fn": "abc 123"}), (r"{fn:\w+\s\d+}", "abc 123", {"fn": "abc 123"}), ), ) async def test_add_route_with_re( router: web.UrlDispatcher, pattern: str, url: str, expected: dict[str, str] ) -> None: handler = make_handler() router.add_route("GET", f"/handler/{pattern}", handler) req = make_mocked_request("GET", f"/handler/{url}") info = await router.resolve(req) assert info is not None assert info == expected async def test_add_route_with_re_and_slashes(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", r"/handler/{to:[^/]+/?}", handler) req = make_mocked_request("GET", "/handler/1234/") info = await router.resolve(req) assert info is not None assert {"to": "1234/"} == info router.add_route("GET", r"/handler/{to:.+}", handler) req = make_mocked_request("GET", "/handler/1234/5/6/7") info = await router.resolve(req) 
assert info is not None assert {"to": "1234/5/6/7"} == info async def test_add_route_with_re_not_match(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", r"/handler/{to:\d+}", handler) req = make_mocked_request("GET", "/handler/tail") match_info = await router.resolve(req) assert isinstance(match_info.route, SystemRoute) assert {} == match_info with pytest.raises(web.HTTPNotFound): await match_info.handler(req) async def test_add_route_with_re_including_slashes(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", r"/handler/{to:.+}/tail", handler) req = make_mocked_request("GET", "/handler/re/with/slashes/tail") info = await router.resolve(req) assert info is not None assert {"to": "re/with/slashes"} == info def test_add_route_with_invalid_re(router: web.UrlDispatcher) -> None: handler = make_handler() with pytest.raises(ValueError) as ctx: router.add_route("GET", r"/handler/{to:+++}", handler) s = str(ctx.value) assert s.startswith( "Bad pattern '" + PATH_SEP + "handler" + PATH_SEP + "(?P+++)': nothing to repeat" ) assert ctx.value.__cause__ is None def test_route_dynamic_with_regex_spec(router: web.UrlDispatcher) -> None: handler = make_handler() route = router.add_route("GET", r"/get/{num:^\d+}", handler, name="name") url = route.url_for(num="123") assert "/get/123" == str(url) def test_route_dynamic_with_regex_spec_and_trailing_slash( router: web.UrlDispatcher, ) -> None: handler = make_handler() route = router.add_route("GET", r"/get/{num:^\d+}/", handler, name="name") url = route.url_for(num="123") assert "/get/123/" == str(url) def test_route_dynamic_with_regex(router: web.UrlDispatcher) -> None: handler = make_handler() route = router.add_route("GET", r"/{one}/{two:.+}", handler) url = route.url_for(one="1", two="2") assert "/1/2" == str(url) def test_route_dynamic_quoting(router: web.UrlDispatcher) -> None: handler = make_handler() route = router.add_route("GET", r"/пре %2Fфикс/{arg}", 
handler) url = route.url_for(arg="1 2/текст%2F") assert url.path == "/пре /фикс/1 2/текст%2F" assert str(url) == ( "/%D0%BF%D1%80%D0%B5%20%2F%D1%84%D0%B8%D0%BA%D1%81" "/1%202/%D1%82%D0%B5%D0%BA%D1%81%D1%82%252F" ) async def test_regular_match_info(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/get/{name}", handler) req = make_mocked_request("GET", "/get/john") match_info = await router.resolve(req) assert {"name": "john"} == match_info assert repr(match_info).startswith(" None: handler = make_handler() router.add_route("GET", "/get/{version}", handler) req = make_mocked_request("GET", "/get/1.0+test") match_info = await router.resolve(req) assert {"version": "1.0+test"} == match_info async def test_not_found_repr(router: web.UrlDispatcher) -> None: req = make_mocked_request("POST", "/path/to") match_info = await router.resolve(req) assert "" == repr(match_info) async def test_not_allowed_repr(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/path/to", handler) handler2 = make_handler() router.add_route("POST", "/path/to", handler2) req = make_mocked_request("PUT", "/path/to") match_info = await router.resolve(req) assert "" == repr(match_info) def test_default_expect_handler(router: web.UrlDispatcher) -> None: route = router.add_route("GET", "/", make_handler()) assert route._expect_handler is _default_expect_handler def test_custom_expect_handler_plain(router: web.UrlDispatcher) -> None: async def handler(request: web.Request) -> NoReturn: assert False route = router.add_route("GET", "/", make_handler(), expect_handler=handler) assert route._expect_handler is handler assert isinstance(route, web.ResourceRoute) def test_custom_expect_handler_dynamic(router: web.UrlDispatcher) -> None: async def handler(request: web.Request) -> NoReturn: assert False route = router.add_route( "GET", "/get/{name}", make_handler(), expect_handler=handler ) assert route._expect_handler is handler assert 
isinstance(route, web.ResourceRoute) def test_expect_handler_non_coroutine(router: web.UrlDispatcher) -> None: def handler(request: web.Request) -> NoReturn: assert False with pytest.raises(AssertionError): router.add_route("GET", "/", make_handler(), expect_handler=handler) async def test_dynamic_match_non_ascii(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/{var}", handler) req = make_mocked_request( "GET", "/%D1%80%D1%83%D1%81%20%D1%82%D0%B5%D0%BA%D1%81%D1%82" ) match_info = await router.resolve(req) assert {"var": "рус текст"} == match_info async def test_dynamic_match_with_static_part(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/{name}.html", handler) req = make_mocked_request("GET", "/file.html") match_info = await router.resolve(req) assert {"name": "file"} == match_info async def test_dynamic_match_two_part2(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/{name}.{ext}", handler) req = make_mocked_request("GET", "/file.html") match_info = await router.resolve(req) assert {"name": "file", "ext": "html"} == match_info async def test_dynamic_match_unquoted_path(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/{path}/{subpath}", handler) resource_id = "my%2Fpath%7Cwith%21some%25strange%24characters" req = make_mocked_request("GET", f"/path/{resource_id}") match_info = await router.resolve(req) assert match_info == {"path": "path", "subpath": unquote(resource_id)} async def test_dynamic_match_double_quoted_path(router: web.UrlDispatcher) -> None: """Verify that double-quoted path is unquoted only once.""" handler = make_handler() router.add_route("GET", "/{path}/{subpath}", handler) resource_id = quote("my/path|with!some%strange$characters", safe="") double_quoted_resource_id = quote(resource_id, safe="") req = make_mocked_request("GET", f"/path/{double_quoted_resource_id}") match_info = await 
router.resolve(req) assert match_info == {"path": "path", "subpath": resource_id} def test_add_route_not_started_with_slash(router: web.UrlDispatcher) -> None: with pytest.raises(ValueError): handler = make_handler() router.add_route("GET", "invalid_path", handler) def test_add_route_invalid_method(router: web.UrlDispatcher) -> None: sample_bad_methods = { "BAD METHOD", "B@D_METHOD", "[BAD_METHOD]", "{BAD_METHOD}", "(BAD_METHOD)", "B?D_METHOD", } for bad_method in sample_bad_methods: with pytest.raises(ValueError): handler = make_handler() router.add_route(bad_method, "/path", handler) def test_routes_view_len( router: web.UrlDispatcher, fill_routes: Callable[[], list[web.AbstractRoute]] ) -> None: fill_routes() assert 4 == len(router.routes()) def test_routes_view_iter( router: web.UrlDispatcher, fill_routes: Callable[[], list[web.AbstractRoute]] ) -> None: routes = fill_routes() assert list(routes) == list(router.routes()) def test_routes_view_contains( router: web.UrlDispatcher, fill_routes: Callable[[], list[web.AbstractRoute]] ) -> None: routes = fill_routes() for route in routes: assert route in router.routes() def test_routes_abc(router: web.UrlDispatcher) -> None: assert isinstance(router.routes(), Sized) assert isinstance(router.routes(), Iterable) assert isinstance(router.routes(), Container) def test_named_resources_abc(router: web.UrlDispatcher) -> None: assert isinstance(router.named_resources(), Mapping) assert not isinstance(router.named_resources(), MutableMapping) def test_named_resources(router: web.UrlDispatcher) -> None: route1 = router.add_route("GET", "/plain", make_handler(), name="route1") route2 = router.add_route("GET", "/variable/{name}", make_handler(), name="route2") route3 = router.add_static( "/static", pathlib.Path(aiohttp.__file__).parent, name="route3" ) names = {route1.name, route2.name, route3.name} assert 3 == len(router.named_resources()) for name in names: assert name is not None assert name in router.named_resources() assert 
isinstance(router.named_resources()[name], web.AbstractResource)


def test_resource_iter(router: web.UrlDispatcher) -> None:
    """A resource reports its routes by length and in insertion order."""

    async def endpoint(request: web.Request) -> NoReturn:
        assert False

    resource = router.add_resource("/path")
    get_route = resource.add_route("GET", endpoint)
    post_route = resource.add_route("POST", endpoint)
    assert len(resource) == 2
    assert list(resource) == [get_route, post_route]


def test_view_route(router: web.UrlDispatcher) -> None:
    """A ``web.View`` class registered for "*" is stored as the handler."""
    view_route = router.add_resource("/path").add_route("*", web.View)
    assert view_route.handler is web.View


def test_resource_route_match(router: web.UrlDispatcher) -> None:
    """The route's resource matches its own path with an empty dict."""

    async def endpoint(request: web.Request) -> NoReturn:
        assert False

    route = router.add_resource("/path").add_route("GET", endpoint)
    assert isinstance(route.resource, web.Resource)
    assert route.resource._match("/path") == {}


def test_error_on_double_route_adding(router: web.UrlDispatcher) -> None:
    """Adding the same method twice to a resource raises ``RuntimeError``."""

    async def endpoint(request: web.Request) -> NoReturn:
        assert False

    resource = router.add_resource("/path")
    resource.add_route("GET", endpoint)
    with pytest.raises(RuntimeError):
        resource.add_route("GET", endpoint)


def test_error_on_adding_route_after_wildcard(router: web.UrlDispatcher) -> None:
    """Adding a method after a "*" route raises ``RuntimeError``."""

    async def endpoint(request: web.Request) -> NoReturn:
        assert False

    resource = router.add_resource("/path")
    resource.add_route("*", endpoint)
    with pytest.raises(RuntimeError):
        resource.add_route("GET", endpoint)


async def test_http_exception_is_none_when_resolved(router: web.UrlDispatcher) -> None:
    """A successful resolution carries no ``http_exception``."""
    router.add_route("GET", "/", make_handler())
    match = await router.resolve(make_mocked_request("GET", "/"))
    assert match.http_exception is None


async def test_http_exception_is_not_none_when_not_resolved(
    router: web.UrlDispatcher,
) -> None:
    """A failed resolution carries an ``http_exception`` with status 404."""
    router.add_route("GET", "/", make_handler())
    match = await router.resolve(make_mocked_request("GET", "/abc"))
    assert match.http_exception is not None
    assert match.http_exception.status == 404
async def test_match_info_get_info_plain(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/", handler) req = make_mocked_request("GET", "/") info = await router.resolve(req) assert info.get_info() == {"path": "/"} async def test_match_info_get_info_dynamic(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/{a}", handler) req = make_mocked_request("GET", "/value") info = await router.resolve(req) assert info.get_info() == { "pattern": re.compile(PATH_SEP + "(?P[^{}/]+)"), "formatter": "/{a}", } async def test_match_info_get_info_dynamic2(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/{a}/{b}", handler) req = make_mocked_request("GET", "/path/to") info = await router.resolve(req) assert info.get_info() == { "pattern": re.compile( PATH_SEP + "(?P[^{}/]+)" + PATH_SEP + "(?P[^{}/]+)" ), "formatter": "/{a}/{b}", } def test_static_resource_get_info(router: web.UrlDispatcher) -> None: directory = pathlib.Path(aiohttp.__file__).parent.resolve() resource = router.add_static("/st", directory) info = resource.get_info() assert len(info) == 3 assert info["directory"] == directory assert info["prefix"] == "/st" assert all([type(r) is web.ResourceRoute for r in info["routes"].values()]) async def test_system_route_get_info(router: web.UrlDispatcher) -> None: handler = make_handler() router.add_route("GET", "/", handler) req = make_mocked_request("GET", "/abc") info = await router.resolve(req) assert info.get_info()["http_exception"].status == 404 def test_resources_view_len(router: web.UrlDispatcher) -> None: router.add_resource("/plain") router.add_resource("/variable/{name}") assert 2 == len(router.resources()) def test_resources_view_iter(router: web.UrlDispatcher) -> None: resource1 = router.add_resource("/plain") resource2 = router.add_resource("/variable/{name}") resources = [resource1, resource2] assert list(resources) == list(router.resources()) def 
test_resources_view_contains(router: web.UrlDispatcher) -> None: resource1 = router.add_resource("/plain") resource2 = router.add_resource("/variable/{name}") resources = [resource1, resource2] for resource in resources: assert resource in router.resources() def test_resources_abc(router: web.UrlDispatcher) -> None: assert isinstance(router.resources(), Sized) assert isinstance(router.resources(), Iterable) assert isinstance(router.resources(), Container) def test_static_route_user_home(router: web.UrlDispatcher) -> None: here = pathlib.Path(aiohttp.__file__).parent try: static_dir = pathlib.Path("~") / here.relative_to(pathlib.Path.home()) except ValueError: pytest.skip("aiohttp folder is not placed in user's HOME") route = router.add_static("/st", str(static_dir)) assert here == route.get_info()["directory"] def test_static_route_points_to_file(router: web.UrlDispatcher) -> None: here = pathlib.Path(aiohttp.__file__).parent / "__init__.py" with pytest.raises(ValueError): router.add_static("/st", here) async def test_404_for_static_resource(router: web.UrlDispatcher) -> None: resource = router.add_static("/st", pathlib.Path(aiohttp.__file__).parent) ret = await resource.resolve(make_mocked_request("GET", "/unknown/path")) assert (None, set()) == ret async def test_405_for_resource_adapter(router: web.UrlDispatcher) -> None: resource = router.add_static("/st", pathlib.Path(aiohttp.__file__).parent) ret = await resource.resolve(make_mocked_request("POST", "/st/abc.py")) assert (None, {"HEAD", "GET"}) == ret @pytest.mark.skipif(platform.system() == "Windows", reason="Different path formats") async def test_static_resource_outside_traversal(router: web.UrlDispatcher) -> None: """Test relative path traversing outside root does not resolve.""" static_file = pathlib.Path(aiohttp.__file__) request_path = "/st" + "/.." 
* (len(static_file.parts) - 2) + str(static_file) assert pathlib.Path(request_path).resolve() == static_file resource = router.add_static("/st", static_file.parent) ret = await resource.resolve(make_mocked_request("GET", request_path)) # Should not resolve, otherwise filesystem information may be leaked. assert (None, set()) == ret async def test_check_allowed_method_for_found_resource( router: web.UrlDispatcher, ) -> None: handler = make_handler() resource = router.add_resource("/") resource.add_route("GET", handler) ret = await resource.resolve(make_mocked_request("GET", "/")) assert ret[0] is not None assert {"GET"} == ret[1] def test_url_for_in_static_resource(router: web.UrlDispatcher) -> None: resource = router.add_static("/static", pathlib.Path(aiohttp.__file__).parent) assert URL("/static/file.txt") == resource.url_for(filename="file.txt") def test_url_for_in_static_resource_pathlib(router: web.UrlDispatcher) -> None: resource = router.add_static("/static", pathlib.Path(aiohttp.__file__).parent) assert URL("/static/file.txt") == resource.url_for( filename=pathlib.Path("file.txt") ) def test_url_for_in_resource_route(router: web.UrlDispatcher) -> None: route = router.add_route("GET", "/get/{name}", make_handler(), name="name") assert URL("/get/John") == route.url_for(name="John") def test_subapp_get_info(app: web.Application) -> None: subapp = web.Application() resource = subapp.add_subapp("/pre", subapp) assert resource.get_info() == {"prefix": "/pre", "app": subapp} @pytest.mark.parametrize( "domain,error", [ (None, TypeError), ("", ValueError), ("http://dom", ValueError), ("*.example.com", ValueError), ("example$com", ValueError), ], ) def test_domain_validation_error(domain: str | None, error: type[Exception]) -> None: with pytest.raises(error): Domain(domain) # type: ignore[arg-type] def test_domain_valid() -> None: assert Domain("example.com:81").canonical == "example.com:81" assert MaskDomain("*.example.com").canonical == r".*\.example\.com" assert 
Domain("пуни.код").canonical == "xn--h1ajfq.xn--d1alm" @pytest.mark.parametrize( "a,b,result", [ ("example.com", "example.com", True), ("example.com:81", "example.com:81", True), ("example.com:81", "example.com", False), ("пуникод", "xn--d1ahgkhc2a", True), ("*.example.com", "jpg.example.com", True), ("*.example.com", "a.example.com", True), ("*.example.com", "example.com", False), ], ) def test_match_domain(a: str, b: str, result: bool) -> None: if "*" in a: rule: Domain = MaskDomain(a) else: rule = Domain(a) assert rule.match_domain(b) is result def test_add_subapp_errors(app: web.Application) -> None: with pytest.raises(TypeError): app.add_subapp(1, web.Application()) # type: ignore[arg-type] def test_subapp_rule_resource(app: web.Application) -> None: subapp = web.Application() subapp.router.add_get("/", make_handler()) rule = Domain("example.com") assert rule.get_info() == {"domain": "example.com"} resource = app.add_domain("example.com", subapp) assert resource.canonical == "example.com" assert resource.get_info() == {"rule": resource._rule, "app": subapp} resource.add_prefix("/a") resource.raw_match("/b") assert len(resource) assert list(resource) assert repr(resource).startswith(" None: app = web.Application() with pytest.raises(TypeError): app.add_domain(1, app) # type: ignore[arg-type] async def test_add_domain( app: web.Application, loop: asyncio.AbstractEventLoop ) -> None: subapp1 = web.Application() h1 = make_handler() subapp1.router.add_get("/", h1) app.add_domain("example.com", subapp1) subapp2 = web.Application() h2 = make_handler() subapp2.router.add_get("/", h2) app.add_domain("*.example.com", subapp2) subapp3 = web.Application() h3 = make_handler() subapp3.router.add_get("/", h3) app.add_domain("*", subapp3) request = make_mocked_request("GET", "/", {"host": "example.com"}) match_info = await app.router.resolve(request) assert match_info.route.handler is h1 request = make_mocked_request("GET", "/", {"host": "a.example.com"}) match_info = await 
app.router.resolve(request) assert match_info.route.handler is h2 request = make_mocked_request("GET", "/", {"host": "example2.com"}) match_info = await app.router.resolve(request) assert match_info.route.handler is h3 request = make_mocked_request("POST", "/", {"host": "example.com"}) match_info = await app.router.resolve(request) assert isinstance(match_info.http_exception, web.HTTPMethodNotAllowed) def test_subapp_url_for(app: web.Application) -> None: subapp = web.Application() resource = app.add_subapp("/pre", subapp) with pytest.raises(RuntimeError): resource.url_for() def test_subapp_repr(app: web.Application) -> None: subapp = web.Application() resource = app.add_subapp("/pre", subapp) assert repr(resource).startswith(" None: subapp = web.Application() subapp.router.add_get("/", make_handler(), allow_head=False) subapp.router.add_post("/", make_handler()) resource = app.add_subapp("/pre", subapp) assert len(resource) == 2 def test_subapp_iter(app: web.Application) -> None: subapp = web.Application() r1 = subapp.router.add_get("/", make_handler(), allow_head=False) r2 = subapp.router.add_post("/", make_handler()) resource = app.add_subapp("/pre", subapp) assert list(resource) == [r1, r2] @pytest.mark.parametrize( "route_name", ( "invalid name", "class", ), ) def test_invalid_route_name(router: web.UrlDispatcher, route_name: str) -> None: with pytest.raises(ValueError): router.add_get("/", make_handler(), name=route_name) def test_frozen_router(router: web.UrlDispatcher) -> None: router.freeze() with pytest.raises(RuntimeError): router.add_get("/", make_handler()) def test_frozen_router_subapp(app: web.Application) -> None: subapp = web.Application() subapp.freeze() with pytest.raises(RuntimeError): app.add_subapp("/pre", subapp) def test_frozen_app_on_subapp(app: web.Application) -> None: app.freeze() subapp = web.Application() with pytest.raises(RuntimeError): app.add_subapp("/pre", subapp) def test_set_options_route(router: web.UrlDispatcher) -> None: 
resource = router.add_static("/static", pathlib.Path(aiohttp.__file__).parent) assert all(r.method != "OPTIONS" for r in resource) resource.set_options_route(make_handler()) assert any(r.method == "OPTIONS" for r in resource) with pytest.raises(RuntimeError): resource.set_options_route(make_handler()) def test_dynamic_url_with_name_started_from_underscore( router: web.UrlDispatcher, ) -> None: route = router.add_route("GET", "/get/{_name}", make_handler()) assert URL("/get/John") == route.url_for(_name="John") def test_cannot_add_subapp_with_empty_prefix(app: web.Application) -> None: subapp = web.Application() with pytest.raises(ValueError): app.add_subapp("", subapp) def test_cannot_add_subapp_with_slash_prefix(app: web.Application) -> None: subapp = web.Application() with pytest.raises(ValueError): app.add_subapp("/", subapp) async def test_convert_empty_path_to_slash_on_freezing( router: web.UrlDispatcher, ) -> None: handler = make_handler() route = router.add_get("", handler) resource = route.resource assert resource is not None assert resource.get_info() == {"path": ""} router.freeze() assert resource.get_info() == {"path": "/"} def test_plain_resource_canonical() -> None: canonical = "/plain/path" res = web.PlainResource(path=canonical) assert res.canonical == canonical def test_dynamic_resource_canonical() -> None: canonicals = { "/get/{name}": "/get/{name}", r"/get/{num:^\d+}": "/get/{num}", r"/handler/{to:\d+}": r"/handler/{to}", r"/{one}/{two:.+}": r"/{one}/{two}", } for pattern, canonical in canonicals.items(): res = web.DynamicResource(path=pattern) assert res.canonical == canonical def test_static_resource_canonical() -> None: prefix = "/prefix" directory = str(pathlib.Path(aiohttp.__file__).parent) canonical = prefix res = web.StaticResource(prefix=prefix, directory=directory) assert res.canonical == canonical def test_prefixed_subapp_resource_canonical(app: web.Application) -> None: canonical = "/prefix" subapp = web.Application() res = 
subapp.add_subapp(canonical, subapp) assert res.canonical == canonical async def test_prefixed_subapp_overlap(app: web.Application) -> None: # Subapp should not overshadow other subapps with overlapping prefixes subapp1 = web.Application() handler1 = make_handler() subapp1.router.add_get("/a", handler1) app.add_subapp("/s", subapp1) subapp2 = web.Application() handler2 = make_handler() subapp2.router.add_get("/b", handler2) app.add_subapp("/ss", subapp2) subapp3 = web.Application() handler3 = make_handler() subapp3.router.add_get("/c", handler3) app.add_subapp("/s/s", subapp3) match_info = await app.router.resolve(make_mocked_request("GET", "/s/a")) assert match_info.route.handler is handler1 match_info = await app.router.resolve(make_mocked_request("GET", "/ss/b")) assert match_info.route.handler is handler2 match_info = await app.router.resolve(make_mocked_request("GET", "/s/s/c")) assert match_info.route.handler is handler3 async def test_prefixed_subapp_empty_route(app: web.Application) -> None: subapp = web.Application() handler = make_handler() subapp.router.add_get("", handler) app.add_subapp("/s", subapp) match_info = await app.router.resolve(make_mocked_request("GET", "/s")) assert match_info.route.handler is handler match_info = await app.router.resolve(make_mocked_request("GET", "/s/")) assert "" == repr(match_info) async def test_prefixed_subapp_root_route(app: web.Application) -> None: subapp = web.Application() handler = make_handler() subapp.router.add_get("/", handler) app.add_subapp("/s", subapp) match_info = await app.router.resolve(make_mocked_request("GET", "/s/")) assert match_info.route.handler is handler match_info = await app.router.resolve(make_mocked_request("GET", "/s")) assert "" == repr(match_info) ================================================ FILE: tests/test_web_app.py ================================================ import asyncio import sys from collections.abc import AsyncIterator, Callable, Iterator from contextlib import 
asynccontextmanager from typing import NoReturn from unittest import mock import pytest from aiohttp import log, web from aiohttp.pytest_plugin import AiohttpClient from aiohttp.typedefs import Handler async def test_app_ctor() -> None: app = web.Application() assert app.logger is log.web_logger def test_app_call() -> None: app = web.Application() assert app is app() async def test_app_register_on_finish() -> None: app = web.Application() cb1 = mock.AsyncMock(return_value=None) cb2 = mock.AsyncMock(return_value=None) app.on_cleanup.append(cb1) app.on_cleanup.append(cb2) app.freeze() await app.cleanup() cb1.assert_called_once_with(app) cb2.assert_called_once_with(app) async def test_app_register_coro() -> None: app = web.Application() fut = asyncio.get_event_loop().create_future() async def cb(app: web.Application) -> None: await asyncio.sleep(0.001) fut.set_result(123) app.on_cleanup.append(cb) app.freeze() await app.cleanup() assert fut.done() assert 123 == fut.result() def test_logging() -> None: logger = mock.Mock() app = web.Application() app.logger = logger assert app.logger is logger async def test_on_shutdown() -> None: app = web.Application() called = False async def on_shutdown(app_param: web.Application) -> None: nonlocal called assert app is app_param called = True app.on_shutdown.append(on_shutdown) app.freeze() await app.shutdown() assert called async def test_on_startup() -> None: app = web.Application() long_running1_called = False long_running2_called = False all_long_running_called = False async def long_running1(app_param: web.Application) -> None: nonlocal long_running1_called assert app is app_param long_running1_called = True async def long_running2(app_param: web.Application) -> None: nonlocal long_running2_called assert app is app_param long_running2_called = True async def on_startup_all_long_running(app_param: web.Application) -> None: nonlocal all_long_running_called assert app is app_param all_long_running_called = True await 
asyncio.gather(long_running1(app_param), long_running2(app_param)) app.on_startup.append(on_startup_all_long_running) app.freeze() await app.startup() assert long_running1_called assert long_running2_called assert all_long_running_called def test_appkey() -> None: key = web.AppKey("key", str) app = web.Application() app[key] = "value" assert app[key] == "value" assert len(app) == 1 del app[key] assert len(app) == 0 def test_appkey_repr_concrete() -> None: key = web.AppKey("key", int) assert repr(key) in ( "", # pytest-xdist "", ) key2 = web.AppKey("key", web.Request) assert repr(key2) in ( # pytest-xdist: "", "", ) def test_appkey_repr_nonconcrete() -> None: key = web.AppKey("key", Iterator[int]) if sys.version_info < (3, 11): assert repr(key) in ( # pytest-xdist: "", "", ) else: assert repr(key) in ( # pytest-xdist: "", "", ) def test_appkey_repr_annotated() -> None: key = web.AppKey[Iterator[int]]("key") if sys.version_info < (3, 11): assert repr(key) in ( # pytest-xdist: "", "", ) else: assert repr(key) in ( # pytest-xdist: "", "", ) def test_app_str_keys() -> None: app = web.Application() with pytest.warns( UserWarning, match=r"web_advanced\.html#application-s-config" ) as checker: app["key"] = "value" # Check that the error is emitted at the call site (stacklevel=2) assert checker[0].filename == __file__ assert app["key"] == "value" def test_app_get() -> None: key = web.AppKey("key", int) app = web.Application() assert app.get(key, "foo") == "foo" app[key] = 5 assert app.get(key, "foo") == 5 def test_app_freeze() -> None: app = web.Application() subapp = mock.Mock() subapp._middlewares = () app._subapps.append(subapp) app.freeze() assert subapp.freeze.called app.freeze() assert len(subapp.freeze.call_args_list) == 1 def test_equality() -> None: app1 = web.Application() app2 = web.Application() assert app1 == app1 assert app1 != app2 def test_app_run_middlewares() -> None: root = web.Application() sub = web.Application() root.add_subapp("/sub", sub) 
root.freeze() assert root._run_middlewares is False async def middleware(request: web.Request, handler: Handler) -> web.StreamResponse: assert False root = web.Application(middlewares=[middleware]) sub = web.Application() root.add_subapp("/sub", sub) root.freeze() assert root._run_middlewares is True root = web.Application() sub = web.Application(middlewares=[middleware]) root.add_subapp("/sub", sub) root.freeze() assert root._run_middlewares is True def test_subapp_pre_frozen_after_adding() -> None: app = web.Application() subapp = web.Application() app.add_subapp("/prefix", subapp) assert subapp.pre_frozen assert not subapp.frozen def test_app_inheritance() -> None: with pytest.raises(TypeError): class A(web.Application): # type: ignore[misc] pass def test_app_custom_attr() -> None: app = web.Application() with pytest.raises(AttributeError): app.custom = None # type: ignore[attr-defined] async def test_cleanup_ctx() -> None: app = web.Application() out = [] def f(num: int) -> Callable[[web.Application], AsyncIterator[None]]: async def inner(app: web.Application) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) return inner app.cleanup_ctx.append(f(1)) app.cleanup_ctx.append(f(2)) app.freeze() await app.startup() assert out == ["pre_1", "pre_2"] await app.cleanup() assert out == ["pre_1", "pre_2", "post_2", "post_1"] async def test_cleanup_ctx_exception_on_startup() -> None: app = web.Application() out = [] exc = Exception("fail") def f( num: int, fail: bool = False ) -> Callable[[web.Application], AsyncIterator[None]]: async def inner(app: web.Application) -> AsyncIterator[None]: out.append("pre_" + str(num)) if fail: raise exc yield None out.append("post_" + str(num)) return inner app.cleanup_ctx.append(f(1)) app.cleanup_ctx.append(f(2, True)) app.cleanup_ctx.append(f(3)) app.freeze() with pytest.raises(Exception) as ctx: await app.startup() assert ctx.value is exc assert out == ["pre_1", "pre_2"] await 
app.cleanup() assert out == ["pre_1", "pre_2", "post_1"] async def test_cleanup_ctx_exception_on_cleanup() -> None: app = web.Application() out = [] exc = Exception("fail") def f( num: int, fail: bool = False ) -> Callable[[web.Application], AsyncIterator[None]]: async def inner(app: web.Application) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) if fail: raise exc return inner app.cleanup_ctx.append(f(1)) app.cleanup_ctx.append(f(2, True)) app.cleanup_ctx.append(f(3)) app.freeze() await app.startup() assert out == ["pre_1", "pre_2", "pre_3"] with pytest.raises(Exception) as ctx: await app.cleanup() assert ctx.value is exc assert out == ["pre_1", "pre_2", "pre_3", "post_3", "post_2", "post_1"] async def test_cleanup_ctx_cleanup_after_exception() -> None: app = web.Application() ctx_state = None async def success_ctx(app: web.Application) -> AsyncIterator[None]: nonlocal ctx_state ctx_state = "START" yield ctx_state = "CLEAN" async def fail_ctx(app: web.Application) -> AsyncIterator[NoReturn]: raise Exception() yield # type: ignore[unreachable] # pragma: no cover app.cleanup_ctx.append(success_ctx) app.cleanup_ctx.append(fail_ctx) runner = web.AppRunner(app) try: with pytest.raises(Exception): await runner.setup() finally: await runner.cleanup() assert ctx_state == "CLEAN" @pytest.mark.parametrize("exc_cls", (Exception, asyncio.CancelledError)) async def test_cleanup_ctx_exception_on_cleanup_multiple( exc_cls: type[BaseException], ) -> None: app = web.Application() out = [] def f( num: int, fail: bool = False ) -> Callable[[web.Application], AsyncIterator[None]]: async def inner(app: web.Application) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) if fail: raise exc_cls("fail_" + str(num)) return inner app.cleanup_ctx.append(f(1)) app.cleanup_ctx.append(f(2, True)) app.cleanup_ctx.append(f(3, True)) app.freeze() await app.startup() assert out == ["pre_1", "pre_2", 
"pre_3"] with pytest.raises(web.CleanupError) as ctx: await app.cleanup() exc = ctx.value assert len(exc.exceptions) == 2 assert str(exc.exceptions[0]) == "fail_3" assert str(exc.exceptions[1]) == "fail_2" assert out == ["pre_1", "pre_2", "pre_3", "post_3", "post_2", "post_1"] async def test_cleanup_ctx_multiple_yields() -> None: app = web.Application() out = [] def f(num: int) -> Callable[[web.Application], AsyncIterator[None]]: async def inner(app: web.Application) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) yield None return inner app.cleanup_ctx.append(f(1)) app.freeze() await app.startup() assert out == ["pre_1"] with pytest.raises(RuntimeError): await app.cleanup() assert out == ["pre_1", "post_1"] async def test_cleanup_ctx_with_async_generator_and_asynccontextmanager() -> None: entered = [] async def gen_ctx(app: web.Application) -> AsyncIterator[None]: entered.append("enter-gen") try: yield finally: entered.append("exit-gen") @asynccontextmanager async def cm_ctx(app: web.Application) -> AsyncIterator[None]: entered.append("enter-cm") try: yield finally: entered.append("exit-cm") app = web.Application() app.cleanup_ctx.append(gen_ctx) app.cleanup_ctx.append(cm_ctx) app.freeze() await app.startup() assert "enter-gen" in entered and "enter-cm" in entered await app.cleanup() assert "exit-gen" in entered and "exit-cm" in entered async def test_cleanup_ctx_exception_in_cm_exit() -> None: app = web.Application() exc = RuntimeError("exit failed") @asynccontextmanager async def failing_exit_ctx(app: web.Application) -> AsyncIterator[None]: yield raise exc app.cleanup_ctx.append(failing_exit_ctx) app.freeze() await app.startup() with pytest.raises(RuntimeError) as ctx: await app.cleanup() assert ctx.value is exc async def test_cleanup_ctx_mixed_with_exception_in_cm_exit() -> None: app = web.Application() out = [] async def working_gen(app: web.Application) -> AsyncIterator[None]: out.append("pre_gen") yield 
out.append("post_gen") exc = RuntimeError("cm exit failed") @asynccontextmanager async def failing_exit_cm(app: web.Application) -> AsyncIterator[None]: out.append("pre_cm") yield out.append("post_cm") raise exc app.cleanup_ctx.append(working_gen) app.cleanup_ctx.append(failing_exit_cm) app.freeze() await app.startup() with pytest.raises(RuntimeError) as ctx: await app.cleanup() assert ctx.value is exc assert out == ["pre_gen", "pre_cm", "post_cm", "post_gen"] async def test_subapp_chained_config_dict_visibility( aiohttp_client: AiohttpClient, ) -> None: key1 = web.AppKey("key1", str) key2 = web.AppKey("key2", str) async def main_handler(request: web.Request) -> web.Response: assert request.config_dict[key1] == "val1" assert key2 not in request.config_dict return web.Response(status=200) root = web.Application() root[key1] = "val1" root.add_routes([web.get("/", main_handler)]) async def sub_handler(request: web.Request) -> web.Response: assert request.config_dict[key1] == "val1" assert request.config_dict[key2] == "val2" return web.Response(status=201) sub = web.Application() sub[key2] = "val2" sub.add_routes([web.get("/", sub_handler)]) root.add_subapp("/sub", sub) client = await aiohttp_client(root) resp = await client.get("/") assert resp.status == 200 resp = await client.get("/sub/") assert resp.status == 201 async def test_subapp_chained_config_dict_overriding( aiohttp_client: AiohttpClient, ) -> None: key = web.AppKey("key", str) async def main_handler(request: web.Request) -> web.Response: assert request.config_dict[key] == "val1" return web.Response(status=200) root = web.Application() root[key] = "val1" root.add_routes([web.get("/", main_handler)]) async def sub_handler(request: web.Request) -> web.Response: assert request.config_dict[key] == "val2" return web.Response(status=201) sub = web.Application() sub[key] = "val2" sub.add_routes([web.get("/", sub_handler)]) root.add_subapp("/sub", sub) client = await aiohttp_client(root) resp = await 
client.get("/") assert resp.status == 200 resp = await client.get("/sub/") assert resp.status == 201 async def test_subapp_on_startup(aiohttp_client: AiohttpClient) -> None: subapp = web.Application() startup = web.AppKey("startup", bool) cleanup = web.AppKey("cleanup", bool) startup_called = False async def on_startup(app: web.Application) -> None: nonlocal startup_called startup_called = True app[startup] = True subapp.on_startup.append(on_startup) ctx_pre_called = False ctx_post_called = False async def cleanup_ctx(app: web.Application) -> AsyncIterator[None]: nonlocal ctx_pre_called, ctx_post_called ctx_pre_called = True app[cleanup] = True yield None ctx_post_called = True subapp.cleanup_ctx.append(cleanup_ctx) shutdown_called = False async def on_shutdown(app: web.Application) -> None: nonlocal shutdown_called shutdown_called = True subapp.on_shutdown.append(on_shutdown) cleanup_called = False async def on_cleanup(app: web.Application) -> None: nonlocal cleanup_called cleanup_called = True subapp.on_cleanup.append(on_cleanup) app = web.Application() app.add_subapp("/subapp", subapp) assert not startup_called assert not ctx_pre_called assert not ctx_post_called assert not shutdown_called assert not cleanup_called assert subapp.on_startup.frozen assert subapp.cleanup_ctx.frozen assert subapp.on_shutdown.frozen assert subapp.on_cleanup.frozen assert subapp.router.frozen client = await aiohttp_client(app) assert startup_called assert ctx_pre_called # type: ignore[unreachable] assert not ctx_post_called assert not shutdown_called assert not cleanup_called await client.close() assert startup_called assert ctx_pre_called assert ctx_post_called assert shutdown_called assert cleanup_called @pytest.mark.filterwarnings(r"ignore:.*web\.AppKey:UserWarning") def test_app_iter() -> None: app = web.Application() b = web.AppKey("b", str) c = web.AppKey("c", str) app["a"] = "0" app[b] = "1" app[c] = "2" app["d"] = "4" assert sorted(list(app)) == [b, c, "a", "d"] def 
test_app_forbid_nonslot_attr() -> None: app = web.Application() with pytest.raises(AttributeError): app.unknow_attr # type: ignore[attr-defined] with pytest.raises(AttributeError): app.unknow_attr = 1 # type: ignore[attr-defined] def test_forbid_changing_frozen_app() -> None: app = web.Application() app.freeze() with pytest.raises(RuntimeError): app["key"] = "value" def test_app_boolean() -> None: app = web.Application() assert app ================================================ FILE: tests/test_web_cli.py ================================================ import sys from unittest import mock import pytest from pytest_mock import MockerFixture from aiohttp import web def test_entry_func_empty(mocker: MockerFixture) -> None: error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) argv = [""] with pytest.raises(SystemExit): web.main(argv) error.assert_called_with("'entry-func' not in 'module:function' syntax") def test_entry_func_only_module(mocker: MockerFixture) -> None: argv = ["test"] error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): web.main(argv) error.assert_called_with("'entry-func' not in 'module:function' syntax") def test_entry_func_only_function(mocker: MockerFixture) -> None: argv = [":test"] error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): web.main(argv) error.assert_called_with("'entry-func' not in 'module:function' syntax") def test_entry_func_only_separator(mocker: MockerFixture) -> None: argv = [":"] error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): web.main(argv) error.assert_called_with("'entry-func' not in 'module:function' syntax") def test_entry_func_relative_module(mocker: MockerFixture) -> None: argv = [".a.b:c"] error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): 
web.main(argv) error.assert_called_with("relative module names not supported") def test_entry_func_non_existent_module(mocker: MockerFixture) -> None: argv = ["alpha.beta:func"] mocker.patch("aiohttp.web.import_module", side_effect=ImportError("Test Error")) error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): web.main(argv) error.assert_called_with("unable to import alpha.beta: Test Error") def test_entry_func_non_existent_attribute(mocker: MockerFixture) -> None: argv = ["alpha.beta:func"] import_module = mocker.patch("aiohttp.web.import_module") error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) module = import_module("alpha.beta") del module.func with pytest.raises(SystemExit): web.main(argv) error.assert_called_with( "module {!r} has no attribute {!r}".format("alpha.beta", "func") ) @pytest.mark.skipif(sys.platform.startswith("win32"), reason="Windows not Unix") def test_path_no_host(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch) -> None: argv = "--path=test_path.sock alpha.beta:func".split() mocker.patch("aiohttp.web.import_module") run_app = mocker.patch("aiohttp.web.run_app") with pytest.raises(SystemExit): web.main(argv) run_app.assert_called_with(mock.ANY, path="test_path.sock", host=None, port=None) @pytest.mark.skipif(sys.platform.startswith("win32"), reason="Windows not Unix") def test_path_and_host(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch) -> None: argv = "--path=test_path.sock --host=localhost --port=8000 alpha.beta:func".split() mocker.patch("aiohttp.web.import_module") run_app = mocker.patch("aiohttp.web.run_app") with pytest.raises(SystemExit): web.main(argv) run_app.assert_called_with( mock.ANY, path="test_path.sock", host="localhost", port=8000 ) def test_path_when_unsupported( mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch ) -> None: argv = "--path=test_path.sock alpha.beta:func".split() 
mocker.patch("aiohttp.web.import_module") monkeypatch.delattr("socket.AF_UNIX", raising=False) error = mocker.patch("aiohttp.web.ArgumentParser.error", side_effect=SystemExit) with pytest.raises(SystemExit): web.main(argv) error.assert_called_with( "file system paths not supported by your operating environment" ) def test_entry_func_call(mocker: MockerFixture) -> None: mocker.patch("aiohttp.web.run_app") import_module = mocker.patch("aiohttp.web.import_module") argv = ( "-H testhost -P 6666 --extra-optional-eins alpha.beta:func " "--extra-optional-zwei extra positional args" ).split() module = import_module("alpha.beta") with pytest.raises(SystemExit): web.main(argv) module.func.assert_called_with( ("--extra-optional-eins --extra-optional-zwei extra positional args").split() ) def test_running_application(mocker: MockerFixture) -> None: run_app = mocker.patch("aiohttp.web.run_app") import_module = mocker.patch("aiohttp.web.import_module") exit = mocker.patch("aiohttp.web.ArgumentParser.exit", side_effect=SystemExit) argv = ( "-H testhost -P 6666 --extra-optional-eins alpha.beta:func " "--extra-optional-zwei extra positional args" ).split() module = import_module("alpha.beta") app = module.func() with pytest.raises(SystemExit): web.main(argv) run_app.assert_called_with(app, host="testhost", port=6666, path=None) exit.assert_called_with(message="Stopped\n") ================================================ FILE: tests/test_web_exceptions.py ================================================ import collections import pickle from collections.abc import Mapping from traceback import format_exception from typing import NoReturn import pytest from yarl import URL from aiohttp import web from aiohttp.pytest_plugin import AiohttpClient def test_all_http_exceptions_exported() -> None: assert "HTTPException" in web.__all__ for name in dir(web): if name.startswith("_"): continue obj = getattr(web, name) if isinstance(obj, type) and issubclass(obj, web.HTTPException): assert name 
in web.__all__ async def test_ctor() -> None: resp = web.HTTPOk() assert resp.text == "200: OK" compare: Mapping[str, str] = {"Content-Type": "text/plain"} assert resp.headers == compare assert resp.reason == "OK" assert resp.status == 200 assert bool(resp) async def test_ctor_with_headers() -> None: resp = web.HTTPOk(headers={"X-Custom": "value"}) assert resp.text == "200: OK" compare: Mapping[str, str] = {"Content-Type": "text/plain", "X-Custom": "value"} assert resp.headers == compare assert resp.reason == "OK" assert resp.status == 200 async def test_ctor_content_type() -> None: resp = web.HTTPOk(text="text", content_type="custom") assert resp.text == "text" compare: Mapping[str, str] = {"Content-Type": "custom"} assert resp.headers == compare assert resp.reason == "OK" assert resp.status == 200 assert bool(resp) async def test_ctor_content_type_without_text() -> None: with pytest.deprecated_call( match=r"^content_type without text is deprecated since " r"4\.0 and scheduled for removal in 5\.0 \(#3462\)$", ): resp = web.HTTPResetContent(content_type="custom") assert resp.text is None compare: Mapping[str, str] = {"Content-Type": "custom"} assert resp.headers == compare assert resp.reason == "Reset Content" assert resp.status == 205 assert bool(resp) async def test_ctor_text_for_empty_body() -> None: with pytest.deprecated_call( match=r"^text argument is deprecated for HTTP status 205 since " r"4\.0 and scheduled for removal in 5\.0 \(#3462\),the " r"response should be provided without a body$", ): resp = web.HTTPResetContent(text="text") assert resp.text == "text" compare: Mapping[str, str] = {"Content-Type": "text/plain"} assert resp.headers == compare assert resp.reason == "Reset Content" assert resp.status == 205 def test_terminal_classes_has_status_code() -> None: terminals = set() for name in dir(web): obj = getattr(web, name) if isinstance(obj, type) and issubclass(obj, web.HTTPException): terminals.add(obj) dup = frozenset(terminals) for cls1 in dup: for 
cls2 in dup: if cls1 in cls2.__bases__: terminals.discard(cls1) for cls in terminals: assert cls.status_code is not None codes = collections.Counter(cls.status_code for cls in terminals) assert None not in codes assert 1 == codes.most_common(1)[0][1] def test_with_text() -> None: resp = web.HTTPNotFound(text="Page not found") assert 404 == resp.status assert "Page not found" == resp.text assert "text/plain" == resp.headers["Content-Type"] def test_default_text() -> None: resp = web.HTTPOk() assert "200: OK" == resp.text def test_empty_text_204() -> None: resp = web.HTTPNoContent() assert resp.text is None def test_empty_text_205() -> None: resp = web.HTTPResetContent() assert resp.text is None def test_empty_text_304() -> None: resp = web.HTTPNoContent() resp.text is None def test_no_link_451() -> None: with pytest.raises(TypeError): web.HTTPUnavailableForLegalReasons() # type: ignore[call-arg] def test_link_none_451() -> None: resp = web.HTTPUnavailableForLegalReasons(link=None) assert resp.link is None assert "Link" not in resp.headers def test_link_empty_451() -> None: resp = web.HTTPUnavailableForLegalReasons(link="") assert resp.link is None assert "Link" not in resp.headers def test_link_str_451() -> None: resp = web.HTTPUnavailableForLegalReasons(link="http://warning.or.kr/") assert resp.link == URL("http://warning.or.kr/") assert resp.headers["Link"] == '; rel="blocked-by"' def test_link_url_451() -> None: resp = web.HTTPUnavailableForLegalReasons(link=URL("http://warning.or.kr/")) assert resp.link == URL("http://warning.or.kr/") assert resp.headers["Link"] == '; rel="blocked-by"' def test_link_CRLF_451() -> None: resp = web.HTTPUnavailableForLegalReasons(link="http://warning.or.kr/\r\n") assert "\r\n" not in resp.headers["Link"] def test_HTTPException_retains_cause() -> None: with pytest.raises(web.HTTPException) as ei: try: raise Exception("CustomException") except Exception as exc: raise web.HTTPException() from exc tb = "".join(format_exception(ei.type, 
ei.value, ei.tb)) assert "CustomException" in tb assert "direct cause" in tb class TestHTTPOk: def test_ctor_all(self) -> None: resp = web.HTTPOk( headers={"X-Custom": "value"}, reason="Done", text="text", content_type="custom", ) assert resp.text == "text" compare: Mapping[str, str] = {"X-Custom": "value", "Content-Type": "custom"} assert resp.headers == compare assert resp.reason == "Done" assert resp.status == 200 def test_multiline_reason(self) -> None: with pytest.raises(ValueError, match=r"Reason cannot contain"): web.HTTPOk(reason="Bad\r\nInjected-header: foo") def test_reason_with_cr(self) -> None: with pytest.raises(ValueError, match=r"Reason cannot contain"): web.HTTPOk(reason="OK\rSet-Cookie: evil=1") def test_reason_with_lf(self) -> None: with pytest.raises(ValueError, match=r"Reason cannot contain"): web.HTTPOk(reason="OK\nSet-Cookie: evil=1") def test_pickle(self) -> None: resp = web.HTTPOk( headers={"X-Custom": "value"}, reason="Done", text="text", content_type="custom", ) resp.foo = "bar" # type: ignore[attr-defined] for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(resp, proto) resp2 = pickle.loads(pickled) assert resp2.text == "text" assert resp2.headers == resp.headers assert resp2.reason == "Done" assert resp2.status == 200 assert resp2.foo == "bar" async def test_app(self, aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: raise web.HTTPOk() app = web.Application() app.router.add_get("/", handler) cli = await aiohttp_client(app) resp = await cli.get("/") assert 200 == resp.status txt = await resp.text() assert "200: OK" == txt class TestHTTPFound: def test_location_str(self) -> None: exc = web.HTTPFound(location="/redirect") assert exc.location == URL("/redirect") assert exc.headers["Location"] == "/redirect" def test_location_url(self) -> None: exc = web.HTTPFound(location=URL("/redirect")) assert exc.location == URL("/redirect") assert exc.headers["Location"] == "/redirect" 
def test_empty_location(self) -> None: with pytest.raises(ValueError): web.HTTPFound(location="") with pytest.raises(ValueError): web.HTTPFound(location=None) # type: ignore[arg-type] def test_location_CRLF(self) -> None: exc = web.HTTPFound(location="/redirect\r\n") assert "\r\n" not in exc.headers["Location"] def test_pickle(self) -> None: resp = web.HTTPFound( location="http://example.com", headers={"X-Custom": "value"}, reason="Wow", text="text", content_type="custom", ) resp.foo = "bar" # type: ignore[attr-defined] for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(resp, proto) resp2 = pickle.loads(pickled) assert resp2.location == URL("http://example.com") assert resp2.text == "text" assert resp2.headers == resp.headers assert resp2.reason == "Wow" assert resp2.status == 302 assert resp2.foo == "bar" async def test_app(self, aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: raise web.HTTPFound(location="/redirect") app = web.Application() app.router.add_get("/", handler) cli = await aiohttp_client(app) resp = await cli.get("/", allow_redirects=False) assert 302 == resp.status txt = await resp.text() assert "302: Found" == txt assert "/redirect" == resp.headers["location"] class TestHTTPMethodNotAllowed: async def test_ctor(self) -> None: resp = web.HTTPMethodNotAllowed( "GET", ["POST", "PUT"], headers={"X-Custom": "value"}, reason="Unsupported", text="text", content_type="custom", ) assert resp.method == "GET" assert resp.allowed_methods == {"POST", "PUT"} assert resp.text == "text" compare: Mapping[str, str] = { "X-Custom": "value", "Content-Type": "custom", "Allow": "POST,PUT", } assert resp.headers == compare assert resp.reason == "Unsupported" assert resp.status == 405 def test_pickle(self) -> None: resp = web.HTTPMethodNotAllowed( method="GET", allowed_methods=("POST", "PUT"), headers={"X-Custom": "value"}, reason="Unsupported", text="text", content_type="custom", ) resp.foo = "bar" # 
type: ignore[attr-defined] for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(resp, proto) resp2 = pickle.loads(pickled) assert resp2.method == "GET" assert resp2.allowed_methods == {"POST", "PUT"} assert resp2.text == "text" assert resp2.headers == resp.headers assert resp2.reason == "Unsupported" assert resp2.status == 405 assert resp2.foo == "bar" class TestHTTPRequestEntityTooLarge: def test_ctor(self) -> None: resp = web.HTTPRequestEntityTooLarge( max_size=100, actual_size=123, headers={"X-Custom": "value"}, reason="Too large", ) assert resp.text == ( "Maximum request body size 100 exceeded, actual body size 123" ) compare: Mapping[str, str] = {"X-Custom": "value", "Content-Type": "text/plain"} assert resp.headers == compare assert resp.reason == "Too large" assert resp.status == 413 def test_pickle(self) -> None: resp = web.HTTPRequestEntityTooLarge( 100, actual_size=123, headers={"X-Custom": "value"}, reason="Too large" ) resp.foo = "bar" # type: ignore[attr-defined] for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(resp, proto) resp2 = pickle.loads(pickled) assert resp2.text == resp.text assert resp2.headers == resp.headers assert resp2.reason == "Too large" assert resp2.status == 413 assert resp2.foo == "bar" class TestHTTPUnavailableForLegalReasons: def test_ctor(self) -> None: exc = web.HTTPUnavailableForLegalReasons( link="http://warning.or.kr/", headers={"X-Custom": "value"}, reason="Zaprescheno", text="text", content_type="custom", ) assert exc.link == URL("http://warning.or.kr/") assert exc.text == "text" compare: Mapping[str, str] = { "X-Custom": "value", "Content-Type": "custom", "Link": '; rel="blocked-by"', } assert exc.headers == compare assert exc.reason == "Zaprescheno" assert exc.status == 451 def test_no_link(self) -> None: with pytest.raises(TypeError): web.HTTPUnavailableForLegalReasons() # type: ignore[call-arg] def test_none_link(self) -> None: exc = 
web.HTTPUnavailableForLegalReasons(link=None) assert exc.link is None assert "Link" not in exc.headers def test_empty_link(self) -> None: exc = web.HTTPUnavailableForLegalReasons(link="") assert exc.link is None assert "Link" not in exc.headers def test_link_str(self) -> None: exc = web.HTTPUnavailableForLegalReasons(link="http://warning.or.kr/") assert exc.link == URL("http://warning.or.kr/") assert exc.headers["Link"] == '; rel="blocked-by"' def test_link_url(self) -> None: exc = web.HTTPUnavailableForLegalReasons(link=URL("http://warning.or.kr/")) assert exc.link == URL("http://warning.or.kr/") assert exc.headers["Link"] == '; rel="blocked-by"' def test_link_CRLF(self) -> None: exc = web.HTTPUnavailableForLegalReasons(link="http://warning.or.kr/\r\n") assert "\r\n" not in exc.headers["Link"] def test_pickle(self) -> None: resp = web.HTTPUnavailableForLegalReasons( link="http://warning.or.kr/", headers={"X-Custom": "value"}, reason="Zaprescheno", text="text", content_type="custom", ) resp.foo = "bar" # type: ignore[attr-defined] for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(resp, proto) resp2 = pickle.loads(pickled) assert resp2.link == URL("http://warning.or.kr/") assert resp2.text == "text" assert resp2.headers == resp.headers assert resp2.reason == "Zaprescheno" assert resp2.status == 451 assert resp2.foo == "bar" ================================================ FILE: tests/test_web_functional.py ================================================ import asyncio import io import json import pathlib import socket import sys from collections.abc import AsyncIterator, Awaitable, Callable, Generator from typing import NoReturn from unittest import mock import pytest from multidict import CIMultiDictProxy, MultiDict from pytest_mock import MockerFixture from yarl import URL import aiohttp from aiohttp import ( FormData, HttpVersion, HttpVersion10, HttpVersion11, TraceConfig, multipart, web, ) from aiohttp.abc import AbstractResolver, 
ResolveResult from aiohttp.compression_utils import ZLibBackend, ZLibCompressObjProtocol from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.typedefs import Handler, Middleware from aiohttp.web_protocol import RequestHandler try: import brotlicffi as brotli except ImportError: import brotli try: import ssl except ImportError: ssl = None # type: ignore[assignment] @pytest.fixture def here() -> pathlib.Path: return pathlib.Path(__file__).parent @pytest.fixture def fname(here: pathlib.Path) -> pathlib.Path: return here / "conftest.py" def new_dummy_form() -> FormData: form = FormData() form.add_field("name", b"123") return form async def test_simple_get(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: body = await request.read() assert b"" == body return web.Response(body=b"OK") app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status txt = await resp.text() assert "OK" == txt resp.release() async def test_simple_get_with_text(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: body = await request.read() assert b"" == body return web.Response(text="OK", headers={"content-type": "text/plain"}) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status txt = await resp.text() assert "OK" == txt resp.release() async def test_handler_returns_not_response( aiohttp_server: AiohttpServer, aiohttp_client: AiohttpClient ) -> None: asyncio.get_event_loop().set_debug(True) logger = mock.Mock() async def handler(request: web.Request) -> str: return "abc" app = web.Application() app.router.add_get("/", handler) # type: ignore[arg-type] server = await aiohttp_server(app, logger=logger) client = await 
aiohttp_client(server) async with client.get("/") as resp: assert resp.status == 500 async def test_handler_returns_none( aiohttp_server: AiohttpServer, aiohttp_client: AiohttpClient ) -> None: asyncio.get_event_loop().set_debug(True) logger = mock.Mock() async def handler(request: web.Request) -> None: return None app = web.Application() app.router.add_get("/", handler) # type: ignore[arg-type] server = await aiohttp_server(app, logger=logger) client = await aiohttp_client(server) async with client.get("/") as resp: assert resp.status == 500 async def test_handler_returns_not_response_after_100expect( aiohttp_server: AiohttpServer, aiohttp_client: AiohttpClient ) -> None: async def handler(request: web.Request) -> NoReturn: raise Exception("foo") app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) async with client.get("/", expect100=True) as resp: assert resp.status == 500 async def test_head_returns_empty_body(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=b"test") app = web.Application() app.router.add_head("/", handler) client = await aiohttp_client(app, version=HttpVersion11) resp = await client.head("/") assert 200 == resp.status txt = await resp.text() assert "" == txt # The Content-Length header should be set to 4 which is # the length of the response body if it would have been # returned by a GET request. 
assert resp.headers["Content-Length"] == "4" @pytest.mark.parametrize("status", (201, 204, 404)) async def test_default_content_type_no_body( aiohttp_client: AiohttpClient, status: int ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(status=status) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) async with client.get("/") as resp: assert resp.status == status assert await resp.read() == b"" assert "Content-Type" not in resp.headers async def test_response_before_complete(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=b"OK") app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) data = b"0" * 1024 * 1024 resp = await client.post("/", data=data) assert 200 == resp.status text = await resp.text() assert "OK" == text resp.release() @pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()") async def test_cancel_shutdown(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: t = asyncio.create_task(request.protocol.shutdown()) # Ensure it's started waiting await asyncio.sleep(0) t.cancel() # Cancellation should not be suppressed with pytest.raises(asyncio.CancelledError): await t # Repeat for second waiter in shutdown() with mock.patch.object(request.protocol, "_request_in_progress", False): with mock.patch.object(request.protocol, "_current_request", None): t = asyncio.create_task(request.protocol.shutdown()) await asyncio.sleep(0) t.cancel() with pytest.raises(asyncio.CancelledError): await t return web.Response(body=b"OK") app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) async with client.get("/") as resp: assert resp.status == 200 txt = await resp.text() assert txt == "OK" async def test_post_form(aiohttp_client: AiohttpClient) -> None: async def handler(request: 
web.Request) -> web.Response: data = await request.post() assert {"a": "1", "b": "2", "c": ""} == data return web.Response(body=b"OK") app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data={"a": "1", "b": "2", "c": ""}) assert 200 == resp.status txt = await resp.text() assert "OK" == txt resp.release() async def test_post_text(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.text() assert "русский" == data data2 = await request.text() assert data == data2 return web.Response(text=data) app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data="русский") assert 200 == resp.status txt = await resp.text() assert "русский" == txt resp.release() async def test_post_json(aiohttp_client: AiohttpClient) -> None: dct = {"key": "текст"} async def handler(request: web.Request) -> web.Response: data = await request.json() assert dct == data data2 = await request.json(loads=json.loads) assert data == data2 resp = web.Response() resp.content_type = "application/json" resp.body = json.dumps(data).encode("utf8") return resp app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) headers = {"Content-Type": "application/json"} resp = await client.post("/", data=json.dumps(dct), headers=headers) assert 200 == resp.status data = await resp.json() assert dct == data resp.release() async def test_multipart(aiohttp_client: AiohttpClient) -> None: with multipart.MultipartWriter() as writer: writer.append("test") writer.append_json({"passed": True}) async def handler(request: web.Request) -> web.Response: reader = await request.multipart() assert isinstance(reader, multipart.MultipartReader) part = await reader.next() assert isinstance(part, multipart.BodyPartReader) thing = await part.text() assert thing == "test" part = await reader.next() 
assert isinstance(part, multipart.BodyPartReader) assert part.headers["Content-Type"] == "application/json" json_thing = await part.json() assert json_thing == {"passed": True} resp = web.Response() resp.content_type = "application/json" resp.body = b"" return resp app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data=writer) assert 200 == resp.status resp.release() async def test_multipart_empty(aiohttp_client: AiohttpClient) -> None: with multipart.MultipartWriter() as writer: pass async def handler(request: web.Request) -> web.Response: reader = await request.multipart() assert isinstance(reader, multipart.MultipartReader) async for part in reader: assert False, f"Unexpected part found in reader: {part!r}" return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data=writer) assert 200 == resp.status resp.release() async def test_multipart_content_transfer_encoding( aiohttp_client: AiohttpClient, ) -> None: # For issue #1168 with multipart.MultipartWriter() as writer: writer.append( b"\x00" * 10, headers={"Content-Transfer-Encoding": "binary"}, ) async def handler(request: web.Request) -> web.Response: reader = await request.multipart() assert isinstance(reader, multipart.MultipartReader) part = await reader.next() assert isinstance(part, multipart.BodyPartReader) assert part.headers["Content-Transfer-Encoding"] == "binary" thing = await part.read() assert thing == b"\x00" * 10 resp = web.Response() resp.content_type = "application/json" resp.body = b"" return resp app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data=writer) assert 200 == resp.status resp.release() async def test_render_redirect(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: raise 
web.HTTPMovedPermanently(location="/path") app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/", allow_redirects=False) assert 301 == resp.status txt = await resp.text() assert "301: Moved Permanently" == txt assert "/path" == resp.headers["location"] resp.release() async def test_post_single_file(aiohttp_client: AiohttpClient) -> None: here = pathlib.Path(__file__).parent def check_file(fs: aiohttp.web_request.FileField) -> None: fullname = here / fs.filename with fullname.open("rb") as f: test_data = f.read() data = fs.file.read() assert test_data == data async def handler(request: web.Request) -> web.Response: data = await request.post() assert ["data.unknown_mime_type"] == list(data.keys()) for fs in data.values(): assert isinstance(fs, aiohttp.web_request.FileField) await asyncio.to_thread(check_file, fs) fs.file.close() resp = web.Response(body=b"OK") return resp app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) fname = here / "data.unknown_mime_type" with fname.open("rb") as fd: resp = await client.post("/", data=[fd]) assert 200 == resp.status resp.release() async def test_files_upload_with_same_key(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.post() files = data.getall("file") file_names = set() for _file in files: assert isinstance(_file, aiohttp.web_request.FileField) assert not _file.file.closed if _file.filename == "test1.jpeg": assert await asyncio.to_thread(_file.file.read) == b"binary data 1" if _file.filename == "test2.jpeg": assert await asyncio.to_thread(_file.file.read) == b"binary data 2" file_names.add(_file.filename) _file.file.close() assert len(files) == 2 assert file_names == {"test1.jpeg", "test2.jpeg"} resp = web.Response(body=b"OK") return resp app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) data = FormData() 
data.add_field( "file", b"binary data 1", content_type="image/jpeg", filename="test1.jpeg" ) data.add_field( "file", b"binary data 2", content_type="image/jpeg", filename="test2.jpeg" ) resp = await client.post("/", data=data) assert 200 == resp.status resp.release() async def test_post_files(aiohttp_client: AiohttpClient) -> None: here = pathlib.Path(__file__).parent def check_file(fs: aiohttp.web_request.FileField) -> None: fullname = here / fs.filename with fullname.open("rb") as f: test_data = f.read() data = fs.file.read() assert test_data == data async def handler(request: web.Request) -> web.Response: data = await request.post() assert ["data.unknown_mime_type", "conftest.py"] == list(data.keys()) for fs in data.values(): assert isinstance(fs, aiohttp.web_request.FileField) await asyncio.to_thread(check_file, fs) fs.file.close() resp = web.Response(body=b"OK") return resp app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with (here / "data.unknown_mime_type").open("rb") as f1: with (here / "conftest.py").open("rb") as f2: resp = await client.post("/", data=[f1, f2]) assert 200 == resp.status resp.release() async def test_release_post_data(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: await request.release() chunk = await request.content.readany() assert chunk == b"" return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data="post text") assert 200 == resp.status resp.release() async def test_post_form_with_duplicate_keys(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.post() lst = list(data.items()) assert [("a", "1"), ("a", "2")] == lst return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data=MultiDict([("a", "1"), 
("a", "2")])) assert 200 == resp.status resp.release() def test_repr_for_application() -> None: app = web.Application() assert f"" == repr(app) async def test_expect_default_handler_unknown(aiohttp_client: AiohttpClient) -> None: # Test default Expect handler for unknown Expect value. # A server that does not understand or is unable to comply with any of # the expectation values in the Expect field of a request MUST respond # with appropriate error status. The server MUST respond with a 417 # (Expectation Failed) status if any of the expectations cannot be met # or, if there are other problems with the request, some other 4xx # status. # http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.20 async def handler(request: web.Request) -> web.Response: assert False app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", headers={"Expect": "SPAM"}) assert 417 == resp.status resp.release() async def test_100_continue(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.post() assert b"123" == data["name"] return web.Response() form = FormData() form.add_field("name", b"123") app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data=form, expect100=True) assert 200 == resp.status resp.release() async def test_100_continue_custom(aiohttp_client: AiohttpClient) -> None: expect_received = False async def handler(request: web.Request) -> web.Response: data = await request.post() assert b"123" == data["name"] return web.Response() async def expect_handler(request: web.Request) -> None: nonlocal expect_received expect_received = True assert request.version == HttpVersion11 await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n") app = web.Application() app.router.add_post("/", handler, expect_handler=expect_handler) client = await aiohttp_client(app) resp = await 
client.post("/", data=new_dummy_form(), expect100=True) assert 200 == resp.status assert expect_received resp.release() async def test_100_continue_custom_response(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: data = await request.post() assert b"123", data["name"] return web.Response() async def expect_handler(request: web.Request) -> None: assert request.version == HttpVersion11 if auth_err: raise web.HTTPForbidden() await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n") app = web.Application() app.router.add_post("/", handler, expect_handler=expect_handler) client = await aiohttp_client(app) auth_err = False resp = await client.post("/", data=new_dummy_form(), expect100=True) assert 200 == resp.status resp.release() auth_err = True resp = await client.post("/", data=new_dummy_form(), expect100=True) assert 403 == resp.status resp.release() async def test_expect_handler_custom_response(aiohttp_client: AiohttpClient) -> None: cache = {"foo": "bar"} async def handler(request: web.Request) -> web.Response: return web.Response(text="handler") async def expect_handler(request: web.Request) -> web.Response | None: k = request.headers["X-Key"] cached_value = cache.get(k) return web.Response(text=cached_value) if cached_value else None app = web.Application() # expect_handler is only typed on add_route(). 
app.router.add_route("POST", "/", handler, expect_handler=expect_handler) client = await aiohttp_client(app) async with client.post("/", expect100=True, headers={"X-Key": "foo"}) as resp: assert resp.status == 200 assert await resp.text() == "bar" async with client.post("/", expect100=True, headers={"X-Key": "spam"}) as resp: assert resp.status == 200 assert await resp.text() == "handler" async def test_100_continue_for_not_found(aiohttp_client: AiohttpClient) -> None: app = web.Application() client = await aiohttp_client(app) resp = await client.post("/not_found", data="data", expect100=True) assert 404 == resp.status resp.release() async def test_100_continue_for_not_allowed(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.get("/", expect100=True) assert 405 == resp.status resp.release() async def test_http11_keep_alive_default(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, version=HttpVersion11) resp = await client.get("/") assert 200 == resp.status assert resp.version == HttpVersion11 assert "Connection" not in resp.headers resp.release() async def test_http10_keep_alive_default(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, version=HttpVersion10) async with client.get("/") as resp: assert 200 == resp.status assert resp.version == HttpVersion10 assert resp.headers["Connection"] == "keep-alive" async def test_http10_keep_alive_with_headers_close( aiohttp_client: AiohttpClient, ) -> None: async def handler(request: web.Request) -> web.Response: await request.read() return 
web.Response(body=b"OK") app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, version=HttpVersion10) headers = {"Connection": "close"} resp = await client.get("/", headers=headers) assert 200 == resp.status assert resp.version == HttpVersion10 assert "Connection" not in resp.headers resp.release() async def test_http10_keep_alive_with_headers(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: await request.read() return web.Response(body=b"OK") app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, version=HttpVersion10) headers = {"Connection": "keep-alive"} resp = await client.get("/", headers=headers) assert 200 == resp.status assert resp.version == HttpVersion10 assert resp.headers["Connection"] == "keep-alive" resp.release() async def test_upload_file(aiohttp_client: AiohttpClient) -> None: here = pathlib.Path(__file__).parent fname = here / "aiohttp.png" with fname.open("rb") as f: data = f.read() async def handler(request: web.Request) -> web.Response: form = await request.post() form_file = form["file"] assert isinstance(form_file, aiohttp.web_request.FileField) raw_data = await asyncio.to_thread(form_file.file.read) form_file.file.close() assert data == raw_data return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data={"file": io.BytesIO(data)}) assert 200 == resp.status resp.release() async def test_upload_file_object(aiohttp_client: AiohttpClient) -> None: here = pathlib.Path(__file__).parent fname = here / "aiohttp.png" with fname.open("rb") as f: data = f.read() async def handler(request: web.Request) -> web.Response: form = await request.post() form_file = form["file"] assert isinstance(form_file, aiohttp.web_request.FileField) raw_data = await asyncio.to_thread(form_file.file.read) form_file.file.close() assert data == raw_data return 
web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) with fname.open("rb") as f: resp = await client.post("/", data={"file": f}) assert 200 == resp.status resp.release() @pytest.mark.parametrize( "method", ["get", "post", "options", "post", "put", "patch", "delete"] ) async def test_empty_content_for_query_without_body( method: str, aiohttp_client: AiohttpClient ) -> None: async def handler(request: web.Request) -> web.Response: assert not request.body_exists assert not request.can_read_body return web.Response() app = web.Application() app.router.add_route(method, "/", handler) client = await aiohttp_client(app) resp = await client.request(method, "/") assert 200 == resp.status async def test_empty_content_for_query_with_body(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: assert request.body_exists assert request.can_read_body body = await request.read() return web.Response(body=body) app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data=b"data") assert 200 == resp.status resp.release() async def test_get_with_empty_arg(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: assert "arg" in request.query assert "" == request.query["arg"] return web.Response() app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/?arg") assert 200 == resp.status resp.release() async def test_large_header(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) headers = {"Long-Header": "ab" * 8129} resp = await client.get("/", headers=headers) assert 400 == resp.status resp.release() async def test_large_header_allowed( aiohttp_client: AiohttpClient, 
aiohttp_server: AiohttpServer ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_post("/", handler) server = await aiohttp_server(app, max_field_size=81920) client = await aiohttp_client(server) headers = {"Long-Header": "ab" * 8129} resp = await client.post("/", headers=headers) assert 200 == resp.status resp.release() async def test_get_with_empty_arg_with_equal(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: assert "arg" in request.query assert "" == request.query["arg"] return web.Response() app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/?arg=") assert 200 == resp.status resp.release() async def test_response_with_async_gen( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: with fname.open("rb") as f: data = f.read() data_size = len(data) async def stream(f_name: pathlib.Path) -> AsyncIterator[bytes]: with f_name.open("rb") as f: data = await asyncio.to_thread(f.read, 100) while data: yield data data = await asyncio.to_thread(f.read, 100) async def handler(request: web.Request) -> web.Response: headers = {"Content-Length": str(data_size)} return web.Response(body=stream(fname), headers=headers) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status resp_data = await resp.read() assert resp_data == data assert resp.headers.get("Content-Length") == str(len(resp_data)) resp.release() async def test_response_with_async_gen_no_params( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: with fname.open("rb") as f: data = f.read() data_size = len(data) async def stream() -> AsyncIterator[bytes]: with fname.open("rb") as f: data = await asyncio.to_thread(f.read, 100) while data: yield data data = await asyncio.to_thread(f.read, 100) async def handler(request: 
web.Request) -> web.Response: headers = {"Content-Length": str(data_size)} return web.Response(body=stream(), headers=headers) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status resp_data = await resp.read() assert resp_data == data assert resp.headers.get("Content-Length") == str(len(resp_data)) resp.release() async def test_response_with_file( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: outer_file_descriptor = None with fname.open("rb") as f: data = f.read() async def handler(request: web.Request) -> web.Response: nonlocal outer_file_descriptor outer_file_descriptor = fname.open("rb") return web.Response(body=outer_file_descriptor) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status resp_data = await resp.read() expected_content_disposition = 'attachment; filename="conftest.py"' assert resp_data == data assert resp.headers.get("Content-Type") in ( "application/octet-stream", "text/x-python", "text/plain", ) assert resp.headers.get("Content-Length") == str(len(resp_data)) assert resp.headers.get("Content-Disposition") == expected_content_disposition resp.release() assert outer_file_descriptor is not None outer_file_descriptor.close() async def test_response_with_file_ctype( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: outer_file_descriptor = None with fname.open("rb") as f: data = f.read() async def handler(request: web.Request) -> web.Response: nonlocal outer_file_descriptor outer_file_descriptor = fname.open("rb") return web.Response( body=outer_file_descriptor, headers={"content-type": "text/binary"} ) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status resp_data = await resp.read() expected_content_disposition = 'attachment; 
filename="conftest.py"' assert resp_data == data assert resp.headers.get("Content-Type") == "text/binary" assert resp.headers.get("Content-Length") == str(len(resp_data)) assert resp.headers.get("Content-Disposition") == expected_content_disposition resp.release() assert outer_file_descriptor is not None outer_file_descriptor.close() async def test_response_with_payload_disp( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: outer_file_descriptor = None with fname.open("rb") as f: data = f.read() async def handler(request: web.Request) -> web.Response: nonlocal outer_file_descriptor outer_file_descriptor = fname.open("rb") pl = aiohttp.get_payload(outer_file_descriptor) pl.set_content_disposition("inline", filename="test.txt") return web.Response(body=pl, headers={"content-type": "text/binary"}) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status resp_data = await resp.read() assert resp_data == data assert resp.headers.get("Content-Type") == "text/binary" assert resp.headers.get("Content-Length") == str(len(resp_data)) assert resp.headers.get("Content-Disposition") == 'inline; filename="test.txt"' resp.release() assert outer_file_descriptor is not None outer_file_descriptor.close() async def test_response_with_payload_stringio( aiohttp_client: AiohttpClient, fname: pathlib.Path ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=io.StringIO("test")) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status resp_data = await resp.read() assert resp_data == b"test" resp.release() @pytest.fixture(params=["gzip", "deflate", "deflate-raw"]) def compressor_case( request: pytest.FixtureRequest, parametrize_zlib_backend: None, ) -> Generator[tuple[ZLibCompressObjProtocol, str], None, None]: encoding: str = request.param max_wbits: int = 
ZLibBackend.MAX_WBITS encoding_to_wbits: dict[str, int] = { "deflate": max_wbits, "deflate-raw": -max_wbits, "gzip": 16 + max_wbits, } compressor = ZLibBackend.compressobj(wbits=encoding_to_wbits[encoding]) yield (compressor, "deflate" if encoding.startswith("deflate") else encoding) async def test_response_with_precompressed_body( aiohttp_client: AiohttpClient, compressor_case: tuple[ZLibCompressObjProtocol, str], ) -> None: compressor, encoding = compressor_case async def handler(request: web.Request) -> web.Response: headers = {"Content-Encoding": encoding} data = compressor.compress(b"mydata") + compressor.flush() return web.Response(body=data, headers=headers) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status data = await resp.read() assert b"mydata" == data assert resp.headers.get("Content-Encoding") == encoding resp.release() async def test_response_with_precompressed_body_brotli( aiohttp_client: AiohttpClient, ) -> None: async def handler(request: web.Request) -> web.Response: headers = {"Content-Encoding": "br"} return web.Response(body=brotli.compress(b"mydata"), headers=headers) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status data = await resp.read() assert b"mydata" == data assert resp.headers.get("Content-Encoding") == "br" resp.release() async def test_bad_request_payload(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: assert request.method == "POST" with pytest.raises(aiohttp.web.RequestPayloadError): await request.content.read() return web.Response() app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data=b"test", headers={"content-encoding": "gzip"}) assert 200 == resp.status resp.release() async def 
test_stream_response_multiple_chunks(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse() resp.enable_chunked_encoding() await resp.prepare(request) await resp.write(b"x") await resp.write(b"y") await resp.write(b"z") return resp app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status data = await resp.read() assert b"xyz" == data resp.release() async def test_start_without_routes(aiohttp_client: AiohttpClient) -> None: app = web.Application() client = await aiohttp_client(app) resp = await client.get("/") assert 404 == resp.status resp.release() async def test_requests_count(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) assert client.server.handler.requests_count == 0 resp = await client.get("/") assert 200 == resp.status assert client.server.handler.requests_count == 1 resp.release() resp = await client.get("/") assert 200 == resp.status assert client.server.handler.requests_count == 2 resp.release() resp = await client.get("/") assert 200 == resp.status assert client.server.handler.requests_count == 3 resp.release() async def test_redirect_url(aiohttp_client: AiohttpClient) -> None: async def redirector(request: web.Request) -> NoReturn: raise web.HTTPFound(location=URL("/redirected")) async def redirected(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_get("/redirector", redirector) app.router.add_get("/redirected", redirected) client = await aiohttp_client(app) resp = await client.get("/redirector") assert resp.status == 200 resp.release() async def test_simple_subapp(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return 
web.Response(text="OK") app = web.Application() subapp = web.Application() subapp.router.add_get("/to", handler) app.add_subapp("/path", subapp) client = await aiohttp_client(app) resp = await client.get("/path/to") assert resp.status == 200 txt = await resp.text() assert "OK" == txt resp.release() async def test_subapp_reverse_url(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: raise web.HTTPMovedPermanently(location=subapp.router["name"].url_for()) async def handler2(request: web.Request) -> web.Response: return web.Response(text="OK") app = web.Application() subapp = web.Application() subapp.router.add_get("/to", handler) subapp.router.add_get("/final", handler2, name="name") app.add_subapp("/path", subapp) client = await aiohttp_client(app) resp = await client.get("/path/to") assert resp.status == 200 txt = await resp.text() assert "OK" == txt assert resp.url.path == "/path/final" resp.release() async def test_subapp_reverse_variable_url(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: raise web.HTTPMovedPermanently( location=subapp.router["name"].url_for(part="final") ) async def handler2(request: web.Request) -> web.Response: return web.Response(text="OK") app = web.Application() subapp = web.Application() subapp.router.add_get("/to", handler) subapp.router.add_get("/{part}", handler2, name="name") app.add_subapp("/path", subapp) client = await aiohttp_client(app) resp = await client.get("/path/to") assert resp.status == 200 txt = await resp.text() assert "OK" == txt assert resp.url.path == "/path/final" resp.release() async def test_subapp_reverse_static_url(aiohttp_client: AiohttpClient) -> None: fname = "aiohttp.png" async def handler(request: web.Request) -> NoReturn: raise web.HTTPMovedPermanently( location=subapp.router["name"].url_for(filename=fname) ) app = web.Application() subapp = web.Application() subapp.router.add_get("/to", handler) here = 
pathlib.Path(__file__).parent subapp.router.add_static("/static", here, name="name") app.add_subapp("/path", subapp) client = await aiohttp_client(app) resp = await client.get("/path/to") assert resp.url.path == "/path/static/" + fname assert resp.status == 200 body = await resp.read() resp.release() with (here / fname).open("rb") as f: assert body == f.read() async def test_subapp_app(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: assert request.app is subapp return web.Response(text="OK") app = web.Application() subapp = web.Application() subapp.router.add_get("/to", handler) app.add_subapp("/path/", subapp) client = await aiohttp_client(app) resp = await client.get("/path/to") assert resp.status == 200 txt = await resp.text() assert "OK" == txt resp.release() async def test_subapp_not_found(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() subapp = web.Application() subapp.router.add_get("/to", handler) app.add_subapp("/path/", subapp) client = await aiohttp_client(app) resp = await client.get("/path/other") assert resp.status == 404 resp.release() async def test_subapp_not_found2(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() subapp = web.Application() subapp.router.add_get("/to", handler) app.add_subapp("/path/", subapp) client = await aiohttp_client(app) resp = await client.get("/invalid/other") assert resp.status == 404 resp.release() async def test_subapp_not_allowed(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: assert False app = web.Application() subapp = web.Application() subapp.router.add_get("/to", handler) app.add_subapp("/path/", subapp) client = await aiohttp_client(app) resp = await client.post("/path/to") assert resp.status == 405 assert resp.headers["Allow"] == "GET,HEAD" resp.release() 
async def test_subapp_cannot_add_app_in_handler(aiohttp_client: AiohttpClient) -> None:
    """A handler that calls ``match_info.add_app()`` must end in a 500 response.

    The handler never returns normally: it calls ``add_app`` on the request's
    match info and is followed by an unconditional ``assert False``.
    """
    root = web.Application()
    nested = web.Application()

    async def handler(request: web.Request) -> NoReturn:
        request.match_info.add_app(root)
        assert False

    nested.router.add_get("/to", handler)
    root.add_subapp("/path/", nested)

    client = await aiohttp_client(root)
    response = await client.get("/path/to")
    assert 500 == response.status
    response.release()


async def test_old_style_subapp_middlewares(aiohttp_client: AiohttpClient) -> None:
    """Middlewares decorated with ``@web.middleware`` nest from root app to sub-app.

    The same middleware is installed on the root app and on two nested
    sub-apps.  It records ``(1, <app name>)`` before delegating and
    ``(2, <app name>)`` afterwards, so the recorded sequence shows the
    entry order (root first) and the exit order (root last).  Applying the
    decorator is also expected to emit the deprecation warning matched below.
    """
    calls: list[tuple[int, str]] = []
    name = web.AppKey("app", str)
    chain = ("app", "subapp1", "subapp2")

    async def handler(request: web.Request) -> web.Response:
        return web.Response(text="OK")

    deprecation_pattern = (
        r"^Middleware decorator is deprecated since 4\.0 and "
        r"its behaviour is default, you can simply remove "
        r"this decorator\.$"
    )
    with pytest.deprecated_call(match=deprecation_pattern):

        @web.middleware
        async def middleware(
            request: web.Request, handler: Handler
        ) -> web.StreamResponse:
            calls.append((1, request.app[name]))
            inner = await handler(request)
            assert inner.status == 200
            calls.append((2, request.app[name]))
            return inner

    app = web.Application(middlewares=[middleware])
    subapp1 = web.Application(middlewares=[middleware])
    subapp2 = web.Application(middlewares=[middleware])
    # Tag every application so the middleware can tell which one it runs for.
    for application, label in zip((app, subapp1, subapp2), chain):
        application[name] = label

    subapp2.router.add_get("/to", handler)
    subapp1.add_subapp("/b/", subapp2)
    app.add_subapp("/a/", subapp1)

    client = await aiohttp_client(app)
    response = await client.get("/a/b/to")
    assert 200 == response.status

    # Entered root -> leaf, left leaf -> root.
    expected = [(1, label) for label in chain]
    expected += [(2, label) for label in reversed(chain)]
    assert calls == expected
    response.release()


async def test_subapp_on_response_prepare(aiohttp_client: AiohttpClient) -> None:
    """``on_response_prepare`` callbacks run for the root app first, then each sub-app."""
    fired: list[web.Application] = []

    async def handler(request: web.Request) -> web.Response:
        return web.Response(text="OK")

    def recording_signal(
        app: web.Application,
    ) -> Callable[[web.Request, web.StreamResponse], Awaitable[None]]:
        """Build a callback that records *app* each time it is invoked."""

        async def on_response(
            request: web.Request, response: web.StreamResponse
        ) -> None:
            fired.append(app)

        return on_response

    app = web.Application()
    subapp1 = web.Application()
    subapp2 = web.Application()
    for application in (app, subapp1, subapp2):
        application.on_response_prepare.append(recording_signal(application))

    subapp2.router.add_get("/to", handler)
    subapp1.add_subapp("/b/", subapp2)
    app.add_subapp("/a/", subapp1)

    client = await aiohttp_client(app)
    response = await client.get("/a/b/to")
    assert 200 == response.status
    assert fired == [app, subapp1, subapp2]
    response.release()


async def test_subapp_on_startup(aiohttp_server: AiohttpServer) -> None:
    """Starting the server fires ``on_startup`` for the root app, then nested sub-apps."""
    started: list[web.Application] = []

    async def on_signal(app: web.Application) -> None:
        started.append(app)

    app = web.Application()
    subapp1 = web.Application()
    subapp2 = web.Application()
    for application in (app, subapp1, subapp2):
        application.on_startup.append(on_signal)

    subapp1.add_subapp("/b/", subapp2)
    app.add_subapp("/a/", subapp1)

    await aiohttp_server(app)

    assert started == [app, subapp1, subapp2]


async def test_subapp_on_shutdown(aiohttp_server: AiohttpServer) -> None:
    """Closing the server fires ``on_shutdown`` for the root app, then nested sub-apps."""
    stopped: list[web.Application] = []

    async def on_signal(app: web.Application) -> None:
        stopped.append(app)

    app = web.Application()
    subapp1 = web.Application()
    subapp2 = web.Application()
    for application in (app, subapp1, subapp2):
        application.on_shutdown.append(on_signal)

    subapp1.add_subapp("/b/", subapp2)
    app.add_subapp("/a/", subapp1)

    server = await aiohttp_server(app)
    await server.close()

    assert stopped == [app, subapp1, subapp2]


async def test_subapp_on_cleanup(aiohttp_server: AiohttpServer) -> None:
    """Closing the server fires ``on_cleanup`` for the root app, then nested sub-apps."""
    cleaned: list[web.Application] = []

    async def on_signal(app: web.Application) -> None:
        cleaned.append(app)

    app = web.Application()
    subapp1 = web.Application()
    subapp2 = web.Application()
    for application in (app, subapp1, subapp2):
        application.on_cleanup.append(on_signal)

    subapp1.add_subapp("/b/", subapp2)
    app.add_subapp("/a/", subapp1)

    server = await aiohttp_server(app)
    await server.close()

    assert cleaned == [app, subapp1, subapp2]
@pytest.mark.parametrize( "route,expected,middlewares", [ ("/sub/", ["A: root", "C: sub", "D: sub"], "AC"), ("/", ["A: root", "B: root"], "AC"), ("/sub/", ["A: root", "D: sub"], "A"), ("/", ["A: root", "B: root"], "A"), ("/sub/", ["C: sub", "D: sub"], "C"), ("/", ["B: root"], "C"), ("/sub/", ["D: sub"], ""), ("/", ["B: root"], ""), ], ) async def test_subapp_middleware_context( aiohttp_client: AiohttpClient, route: str, expected: list[str], middlewares: str ) -> None: values = [] def show_app_context(appname: str) -> Middleware: async def middleware( request: web.Request, handler: Handler ) -> web.StreamResponse: values.append(f"{appname}: {request.app[my_value]}") return await handler(request) return middleware def make_handler(appname: str) -> Handler: async def handler(request: web.Request) -> web.Response: values.append(f"{appname}: {request.app[my_value]}") return web.Response(text="Ok") return handler app = web.Application() my_value = web.AppKey("my_value", str) app[my_value] = "root" if "A" in middlewares: app.middlewares.append(show_app_context("A")) app.router.add_get("/", make_handler("B")) subapp = web.Application() subapp[my_value] = "sub" if "C" in middlewares: subapp.middlewares.append(show_app_context("C")) subapp.router.add_get("/", make_handler("D")) app.add_subapp("/sub/", subapp) client = await aiohttp_client(app) resp = await client.get(route) assert 200 == resp.status assert "Ok" == await resp.text() assert expected == values resp.release() async def test_custom_date_header(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(headers={"Date": "Sun, 30 Oct 2016 03:13:52 GMT"}) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status assert resp.headers["Date"] == "Sun, 30 Oct 2016 03:13:52 GMT" resp.release() async def test_response_prepared_with_clone(aiohttp_client: AiohttpClient) -> None: 
async def handler(request: web.Request) -> web.StreamResponse: cloned = request.clone() resp = web.StreamResponse() await resp.prepare(cloned) return resp app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status resp.release() async def test_app_max_client_size(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: await request.post() assert False max_size = 1024**2 app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) data = {"long_string": max_size * "x" + "xxx"} with pytest.warns(ResourceWarning): resp = await client.post("/", data=data) assert 413 == resp.status resp_text = await resp.text() assert "Maximum request body size 1048576 exceeded, actual body size" in resp_text # Maximum request body size X exceeded, actual body size X body_size = int(resp_text.split()[-1]) assert body_size >= max_size resp.release() async def test_app_max_client_size_form(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: await request.post() assert False app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) # Verify that entire multipart form can't exceed client size (not just each field). 
form = aiohttp.FormData() for i in range(3): form.add_field(f"f{i}", b"A" * 512000) async with client.post("/", data=form) as resp: assert resp.status == 413 resp_text = await resp.text() assert "Maximum request body size 1048576 exceeded, actual body size" in resp_text async def test_app_max_client_size_adjusted(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: await request.post() return web.Response(body=b"ok") default_max_size = 1024**2 custom_max_size = default_max_size * 2 app = web.Application(client_max_size=custom_max_size) app.router.add_post("/", handler) client = await aiohttp_client(app) data = {"long_string": default_max_size * "x" + "xxx"} with pytest.warns(ResourceWarning): resp = await client.post("/", data=data) assert 200 == resp.status resp_text = await resp.text() assert "ok" == resp_text resp.release() too_large_data = {"log_string": custom_max_size * "x" + "xxx"} with pytest.warns(ResourceWarning): resp = await client.post("/", data=too_large_data) assert 413 == resp.status resp_text = await resp.text() assert "Maximum request body size 2097152 exceeded, actual body size" in resp_text # Maximum request body size X exceeded, actual body size X body_size = int(resp_text.split()[-1]) assert body_size >= custom_max_size resp.release() async def test_app_max_client_size_none(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: await request.post() return web.Response(body=b"ok") default_max_size = 1024**2 app = web.Application(client_max_size=0) app.router.add_post("/", handler) client = await aiohttp_client(app) data = {"long_string": default_max_size * "x" + "xxx"} with pytest.warns(ResourceWarning): resp = await client.post("/", data=data) assert 200 == resp.status resp_text = await resp.text() assert "ok" == resp_text resp.release() too_large_data = {"log_string": default_max_size * 2 * "x"} with pytest.warns(ResourceWarning): resp = await 
client.post("/", data=too_large_data) assert 200 == resp.status resp_text = await resp.text() assert resp_text == "ok" resp.release() async def test_post_max_client_size(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: await request.post() assert False app = web.Application(client_max_size=10) app.router.add_post("/", handler) client = await aiohttp_client(app) with io.BytesIO(b"test") as file_handle: data = {"long_string": 1024 * "x", "file": file_handle} resp = await client.post("/", data=data) assert 413 == resp.status resp_text = await resp.text() assert ( "Maximum request body size 10 exceeded, " "actual body size 1024" in resp_text ) data_file = data["file"] assert isinstance(data_file, io.BytesIO) data_file.close() resp.release() async def test_post_max_client_size_for_file(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: await request.post() assert False app = web.Application(client_max_size=2) app.router.add_post("/", handler) client = await aiohttp_client(app) with io.BytesIO(b"test") as file_handle: data = {"file": file_handle} resp = await client.post("/", data=data) assert 413 == resp.status resp.release() async def test_response_with_bodypart(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: reader = await request.multipart() part = await reader.next() return web.Response(body=part) app = web.Application(client_max_size=2) app.router.add_post("/", handler) client = await aiohttp_client(app) with io.BytesIO(b"test") as file_handle: data = {"file": file_handle} resp = await client.post("/", data=data) assert 200 == resp.status body = await resp.read() assert body == b"test" disp = multipart.parse_content_disposition(resp.headers["content-disposition"]) assert disp == ("attachment", {"name": "file", "filename": "file"}) resp.release() async def test_response_with_bodypart_named( aiohttp_client: AiohttpClient, tmp_path: 
pathlib.Path ) -> None: async def handler(request: web.Request) -> web.Response: reader = await request.multipart() part = await reader.next() return web.Response(body=part) app = web.Application(client_max_size=2) app.router.add_post("/", handler) client = await aiohttp_client(app) f = tmp_path / "foobar.txt" f.write_text("test", encoding="utf8") with f.open("rb") as fd: data = {"file": fd} resp = await client.post("/", data=data) assert 200 == resp.status body = await resp.read() assert body == b"test" disp = multipart.parse_content_disposition(resp.headers["content-disposition"]) assert disp == ("attachment", {"name": "file", "filename": "foobar.txt"}) resp.release() async def test_response_with_bodypart_invalid_name( aiohttp_client: AiohttpClient, ) -> None: async def handler(request: web.Request) -> web.Response: reader = await request.multipart() part = await reader.next() return web.Response(body=part) app = web.Application(client_max_size=2) app.router.add_post("/", handler) client = await aiohttp_client(app) with aiohttp.MultipartWriter() as mpwriter: mpwriter.append(b"test") resp = await client.post("/", data=mpwriter) assert 200 == resp.status body = await resp.read() assert body == b"test" assert "content-disposition" not in resp.headers resp.release() async def test_request_clone(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: r2 = request.clone(method="POST") assert r2.method == "POST" assert r2.match_info is request.match_info return web.Response() app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status resp.release() async def test_await(aiohttp_server: AiohttpServer) -> None: async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse(headers={"content-length": str(4)}) await resp.prepare(request) with pytest.deprecated_call( match=r"^drain method is deprecated, use await 
resp\.write\(\)$", ): await resp.drain() await asyncio.sleep(0.01) await resp.write(b"test") await asyncio.sleep(0.01) await resp.write_eof() return resp app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) async with aiohttp.ClientSession() as session: resp = await session.get(server.make_url("/")) assert resp.status == 200 assert resp.connection is not None await resp.read() resp.release() assert resp.connection is None async def test_response_context_manager(aiohttp_server: AiohttpServer) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) session = aiohttp.ClientSession() resp = await session.get(server.make_url("/")) async with resp: assert resp.status == 200 assert resp.connection is None await session.close() async def test_response_context_manager_error(aiohttp_server: AiohttpServer) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(text="some text") app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) session = aiohttp.ClientSession() cm = session.get(server.make_url("/")) resp = await cm with pytest.raises(RuntimeError): async with resp: assert resp.status == 200 resp.content.set_exception(RuntimeError()) await resp.read() assert resp.closed # Wait for any pending operations to complete await resp.wait_for_close() assert session._connector is not None assert len(session._connector._conns) == 1 await session.close() async def test_client_api_context_manager(aiohttp_server: AiohttpServer) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) async with aiohttp.ClientSession() as session: async with session.get(server.make_url("/")) as resp: assert resp.status 
== 200 assert resp.connection is None assert resp.connection is None async def test_context_manager_close_on_release( aiohttp_server: AiohttpServer, mocker: MockerFixture ) -> None: async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse() await resp.prepare(request) with pytest.deprecated_call( match=r"^drain method is deprecated, use await resp\.write\(\)$", ): await resp.drain() await asyncio.sleep(10) assert False app = web.Application() app.router.add_route("GET", "/", handler) server = await aiohttp_server(app) async with aiohttp.ClientSession() as session: resp = await session.get(server.make_url("/")) assert resp.connection is not None proto = resp.connection._protocol mocker.spy(proto, "close") async with resp: assert resp.status == 200 assert resp.connection is not None assert resp.connection is None assert proto.close.called # type: ignore[unreachable] resp.release() # Trigger handler completion async def test_iter_any(aiohttp_server: AiohttpServer) -> None: data = b"0123456789" * 1024 async def handler(request: web.Request) -> web.Response: buf = [] async for raw in request.content.iter_any(): buf.append(raw) assert b"".join(buf) == data return web.Response() app = web.Application() app.router.add_route("POST", "/", handler) server = await aiohttp_server(app) async with aiohttp.ClientSession() as session: async with session.post(server.make_url("/"), data=data) as resp: assert resp.status == 200 async def test_request_tracing(aiohttp_server: AiohttpServer) -> None: on_request_start = mock.AsyncMock() on_request_end = mock.AsyncMock() on_dns_resolvehost_start = mock.AsyncMock() on_dns_resolvehost_end = mock.AsyncMock() on_request_redirect = mock.AsyncMock() on_connection_create_start = mock.AsyncMock() on_connection_create_end = mock.AsyncMock() async def redirector(request: web.Request) -> NoReturn: raise web.HTTPFound(location=URL("/redirected")) async def redirected(request: web.Request) -> web.Response: return 
web.Response() trace_config = TraceConfig() trace_config.on_request_start.append(on_request_start) trace_config.on_request_end.append(on_request_end) trace_config.on_request_redirect.append(on_request_redirect) trace_config.on_connection_create_start.append(on_connection_create_start) trace_config.on_connection_create_end.append(on_connection_create_end) trace_config.on_dns_resolvehost_start.append(on_dns_resolvehost_start) trace_config.on_dns_resolvehost_end.append(on_dns_resolvehost_end) app = web.Application() app.router.add_get("/redirector", redirector) app.router.add_get("/redirected", redirected) server = await aiohttp_server(app) class FakeResolver(AbstractResolver): _LOCAL_HOST = {0: "127.0.0.1", socket.AF_INET: "127.0.0.1"} def __init__(self, fakes: dict[str, int]): # fakes -- dns -> port dict self._fakes = fakes self._resolver = aiohttp.DefaultResolver() async def close(self) -> None: assert False async def resolve( self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, ) -> list[ResolveResult]: fake_port = self._fakes.get(host) assert fake_port is not None return [ { "hostname": host, "host": self._LOCAL_HOST[family], "port": fake_port, "family": socket.AF_INET, "proto": 0, "flags": socket.AI_NUMERICHOST, } ] resolver = FakeResolver({"example.com": server.port}) connector = aiohttp.TCPConnector(resolver=resolver) client = aiohttp.ClientSession(connector=connector, trace_configs=[trace_config]) resp = await client.get("http://example.com/redirector", data="foo") assert on_request_start.called assert on_request_end.called assert on_dns_resolvehost_start.called assert on_dns_resolvehost_end.called assert on_request_redirect.called assert on_connection_create_start.called assert on_connection_create_end.called resp.release() await client.close() async def test_raise_http_exception(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: raise web.HTTPForbidden() app = web.Application() 
app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert resp.status == 403 resp.release() async def test_request_path(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: assert request.path_qs == "/path%20to?a=1" assert request.path == "/path to" assert request.raw_path == "/path%20to?a=1" return web.Response(body=b"OK") app = web.Application() app.router.add_get("/path to", handler) client = await aiohttp_client(app) resp = await client.get("/path to", params={"a": "1"}) assert 200 == resp.status txt = await resp.text() assert "OK" == txt resp.release() async def test_app_add_routes(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application() app.add_routes([web.get("/get", handler)]) client = await aiohttp_client(app) resp = await client.get("/get") assert resp.status == 200 resp.release() async def test_request_headers_type(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: assert isinstance(request.headers, CIMultiDictProxy) return web.Response() app = web.Application() app.add_routes([web.get("/get", handler)]) client = await aiohttp_client(app) resp = await client.get("/get") assert resp.status == 200 resp.release() async def test_signal_on_error_handler(aiohttp_client: AiohttpClient) -> None: async def on_prepare(request: web.Request, response: web.StreamResponse) -> None: response.headers["X-Custom"] = "val" app = web.Application() app.on_response_prepare.append(on_prepare) client = await aiohttp_client(app) resp = await client.get("/") assert resp.status == 404 assert resp.headers["X-Custom"] == "val" resp.release() @pytest.mark.skipif( "HttpRequestParserC" not in dir(aiohttp.http_parser), reason="C based HTTP parser not available", ) async def test_bad_method_for_c_http_parser_not_hangs( aiohttp_client: AiohttpClient, ) -> None: 
app = web.Application() timeout = aiohttp.ClientTimeout(sock_read=0.2) client = await aiohttp_client(app, timeout=timeout) resp = await client.request("GET1", "/") assert 400 == resp.status async def test_read_bufsize(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: ret = request.content.get_read_buffer_limits() data = await request.text() # read posted data return web.Response(text=f"{data} {ret!r}") app = web.Application(handler_args={"read_bufsize": 2}) app.router.add_post("/", handler) client = await aiohttp_client(app) resp = await client.post("/", data=b"data") assert resp.status == 200 assert await resp.text() == "data (2, 4)" resp.release() @pytest.mark.parametrize( "auto_decompress,len_of", [(True, "uncompressed"), (False, "compressed")] ) @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_auto_decompress( aiohttp_client: AiohttpClient, auto_decompress: bool, len_of: str, ) -> None: async def handler(request: web.Request) -> web.Response: data = await request.read() return web.Response(text=str(len(data))) app = web.Application(handler_args={"auto_decompress": auto_decompress}) app.router.add_post("/", handler) client = await aiohttp_client(app) uncompressed = b"dataaaaaaaaaaaaaaaaaaaaaaaaa" compressor = ZLibBackend.compressobj(wbits=16 + ZLibBackend.MAX_WBITS) compressed = compressor.compress(uncompressed) + compressor.flush() assert len(compressed) != len(uncompressed) headers = {"content-encoding": "gzip"} resp = await client.post("/", data=compressed, headers=headers) assert resp.status == 200 assert await resp.text() == str(len(locals()[len_of])) resp.release() @pytest.mark.parametrize( "status", [101, 204], ) async def test_response_101_204_no_content_length_http11( status: int, aiohttp_client: AiohttpClient ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(status=status) app = web.Application() app.router.add_get("/", handler) client = await 
aiohttp_client(app, version=HttpVersion11) resp = await client.get("/") assert CONTENT_LENGTH not in resp.headers assert TRANSFER_ENCODING not in resp.headers resp.release() async def test_stream_response_headers_204(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.StreamResponse: return web.StreamResponse(status=204) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert CONTENT_TYPE not in resp.headers assert TRANSFER_ENCODING not in resp.headers resp.release() async def test_httpfound_cookies_302(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: resp = web.HTTPFound("/") resp.set_cookie("my-cookie", "cookie-value") raise resp app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/", allow_redirects=False) assert "my-cookie" in resp.cookies resp.release() @pytest.mark.parametrize("status", (101, 204, 304)) @pytest.mark.parametrize("version", (HttpVersion10, HttpVersion11)) async def test_no_body_for_1xx_204_304_responses( aiohttp_client: AiohttpClient, status: int, version: HttpVersion ) -> None: """Test no body is present for for 1xx, 204, and 304 responses.""" async def handler(request: web.Request) -> web.Response: return web.Response(status=status, body=b"should not get to client") app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, version=version) resp = await client.get("/") assert CONTENT_TYPE not in resp.headers assert TRANSFER_ENCODING not in resp.headers await resp.read() == b"" resp.release() async def test_keepalive_race_condition(aiohttp_client: AiohttpClient) -> None: protocol: RequestHandler[web.Request] | None = None orig_data_received = RequestHandler.data_received def delay_received(self: RequestHandler[web.Request], data: bytes) -> None: """Emulate race condition. 
The keepalive callback needs to be called between data_received() and when start() resumes from the waiter set within data_received(). """ orig_data_received(self, data) if protocol is None: # First request creating the keepalive connection. return assert self is protocol assert protocol._keepalive_handle is not None # Cancel existing callback that would run at some point in future. protocol._keepalive_handle.cancel() protocol._keepalive_handle = None # Set next run time into the past and run callback manually. protocol._next_keepalive_close_time = asyncio.get_running_loop().time() - 1 protocol._process_keepalive() async def handler(request: web.Request) -> web.Response: nonlocal protocol protocol = request.protocol return web.Response() target = "aiohttp.web_protocol.RequestHandler.data_received" with mock.patch(target, delay_received): app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) # Open connection, so we have a keepalive connection and reference to protocol. async with client.get("/") as resp: assert resp.status == 200 assert protocol is not None # Make 2nd request which will hit the race condition. 
async with client.get("/") as resp: assert resp.status == 200 async def test_keepalive_expires_on_time(aiohttp_client: AiohttpClient) -> None: """Test that the keepalive handle expires on time.""" async def handler(request: web.Request) -> web.Response: body = await request.read() assert b"" == body return web.Response(body=b"OK") app = web.Application() app.router.add_route("GET", "/", handler) connector = aiohttp.TCPConnector(limit=1) client = await aiohttp_client(app, connector=connector) loop = asyncio.get_running_loop() now = loop.time() # Patch loop time so we can control when the keepalive timeout is processed with mock.patch.object(loop, "time") as loop_time_mock: loop_time_mock.return_value = now resp1 = await client.get("/") await resp1.read() request_handler = client.server.handler.connections[0] # Ensure the keep alive handle is set assert request_handler._keepalive_handle is not None # Set the loop time to exactly the keepalive timeout loop_time_mock.return_value = request_handler._next_keepalive_close_time # sleep twice to ensure the keep alive timeout is processed await asyncio.sleep(0) await asyncio.sleep(0) # Ensure the keep alive handle expires assert request_handler._keepalive_handle is None ================================================ FILE: tests/test_web_log.py ================================================ import datetime import logging import platform import sys from contextvars import ContextVar from typing import NoReturn from unittest import mock import pytest import aiohttp from aiohttp import web from aiohttp.abc import AbstractAccessLogger, AbstractAsyncAccessLogger from aiohttp.pytest_plugin import AiohttpClient, AiohttpRawServer, AiohttpServer from aiohttp.test_utils import make_mocked_request from aiohttp.typedefs import Handler from aiohttp.web_log import AccessLogger from aiohttp.web_response import Response if sys.version_info >= (3, 11): from typing import Self else: from typing import Any as Self IS_PYPY = 
platform.python_implementation() == "PyPy" def test_access_logger_format() -> None: log_format = '%T "%{ETag}o" %X {X} %%P' mock_logger = mock.Mock() access_logger = AccessLogger(mock_logger, log_format) expected = '%s "%s" %%X {X} %%%s' assert expected == access_logger._log_format @pytest.mark.skipif( IS_PYPY, reason=""" Because of patching :py:class:`datetime.datetime`, under PyPy it fails in :py:func:`isinstance` call in :py:meth:`datetime.datetime.__sub__` (called from :py:meth:`aiohttp.AccessLogger._format_t`): *** TypeError: isinstance() arg 2 must be a class, type, or tuple of classes and types (Pdb) from datetime import datetime (Pdb) isinstance(now, datetime) *** TypeError: isinstance() arg 2 must be a class, type, or tuple of classes and types (Pdb) datetime.__class__ (Pdb) isinstance(now, datetime.__class__) False Ref: https://bitbucket.org/pypy/pypy/issues/1187/call-to-isinstance-in-__sub__-self-other Ref: https://github.com/celery/celery/issues/811 Ref: https://stackoverflow.com/a/46102240/595220 """, ) @pytest.mark.parametrize( "log_format,expected,extra", [ ( "%t", "[01/Jan/1843:00:29:56 +0800]", {"request_start_time": "[01/Jan/1843:00:29:56 +0800]"}, ), ( '%a %t %P %r %s %b %T %Tf %D "%{H1}i" "%{H2}i"', ( "127.0.0.2 [01/Jan/1843:00:29:56 +0800] <42> " 'GET /path HTTP/1.1 200 42 3 3.141593 3141593 "a" "b"' ), { "first_request_line": "GET /path HTTP/1.1", "process_id": "<42>", "remote_address": "127.0.0.2", "request_start_time": "[01/Jan/1843:00:29:56 +0800]", "request_time": "3", "request_time_frac": "3.141593", "request_time_micro": "3141593", "response_size": 42, "response_status": 200, "request_header": {"H1": "a", "H2": "b"}, }, ), ], ) def test_access_logger_atoms( monkeypatch: pytest.MonkeyPatch, log_format: str, expected: str, extra: dict[str, object], ) -> None: class PatchedDatetime(datetime.datetime): @classmethod def now(cls, tz: datetime.tzinfo | None = None) -> Self: assert tz is not None # Simulate: real UTC time is 1842-12-31 16:30, 
convert to local tz utc = datetime.datetime(1842, 12, 31, 16, 30, tzinfo=datetime.timezone.utc) local = utc.astimezone(tz) return cls( local.year, local.month, local.day, local.hour, local.minute, local.second, tzinfo=tz, ) monkeypatch.setattr("datetime.datetime", PatchedDatetime) # Mock localtime to return CST (+0800 = 28800 seconds) mock_localtime = mock.Mock() mock_localtime.return_value.tm_gmtoff = 28800 monkeypatch.setattr("aiohttp.web_log.time_mod.localtime", mock_localtime) # Clear cached timezone so it gets rebuilt with our mock AccessLogger._cached_tz = None AccessLogger._cached_tz_expires = 0.0 monkeypatch.setattr("os.getpid", lambda: 42) mock_logger = mock.Mock() access_logger = AccessLogger(mock_logger, log_format) request = mock.Mock( headers={"H1": "a", "H2": "b"}, method="GET", path_qs="/path", version=aiohttp.HttpVersion(1, 1), remote="127.0.0.2", ) response = mock.Mock(headers={}, body_length=42, status=200) access_logger.log(request, response, 3.1415926) assert not mock_logger.exception.called, mock_logger.exception.call_args mock_logger.info.assert_called_with(expected, extra=extra) @pytest.mark.skipif( IS_PYPY, reason="PyPy has issues with patching datetime.datetime", ) def test_access_logger_dst_timezone(monkeypatch: pytest.MonkeyPatch) -> None: """Test that _format_t uses the current local UTC offset, not a cached one. This ensures timestamps are correct during DST transitions. The old implementation used time.timezone which is a constant and doesn't reflect DST changes. 
""" # Simulate a timezone that observes DST (e.g., US Eastern) # During EST: UTC-5 (-18000s), during EDT: UTC-4 (-14400s) gmtoff_est = -18000 # UTC-5 gmtoff_edt = -14400 # UTC-4 class PatchedDatetime(datetime.datetime): @classmethod def now(cls, tz: datetime.tzinfo | None = None) -> Self: assert tz is not None # Simulate: real UTC time is 07:00, convert to local tz utc = datetime.datetime(2024, 3, 10, 7, 0, 0, tzinfo=datetime.timezone.utc) local = utc.astimezone(tz) return cls( local.year, local.month, local.day, local.hour, local.minute, local.second, tzinfo=tz, ) monkeypatch.setattr("datetime.datetime", PatchedDatetime) mock_localtime = mock.Mock() mock_localtime.return_value.tm_gmtoff = gmtoff_est monkeypatch.setattr("aiohttp.web_log.time_mod.localtime", mock_localtime) # Force cache refresh AccessLogger._cached_tz = None AccessLogger._cached_tz_expires = 0.0 mock_logger = mock.Mock() access_logger = AccessLogger(mock_logger, "%t") request = mock.Mock( headers={}, method="GET", path_qs="/", version=(1, 1), remote="127.0.0.1" ) response = mock.Mock(headers={}, body_length=0, status=200) # During EST (UTC-5): time is 07:00-05:00 = 02:00 EST access_logger.log(request, response, 0.0) call1 = mock_logger.info.call_args[0][0] assert "-0500" in call1, f"Expected EST offset in {call1}" mock_logger.reset_mock() # Switch to EDT (UTC-4): force cache invalidation mock_localtime.return_value.tm_gmtoff = gmtoff_edt AccessLogger._cached_tz = None AccessLogger._cached_tz_expires = 0.0 access_logger.log(request, response, 0.0) call2 = mock_logger.info.call_args[0][0] assert "-0400" in call2, f"Expected EDT offset in {call2}" # Verify the hour changed too (02:00 -> 03:00) assert "02:00:00 -0500" in call1 assert "03:00:00 -0400" in call2 # Verify cached tz works too assert access_logger._cached_tz is not None with mock.patch( "aiohttp.web_log.time_mod.time", return_value=access_logger._cached_tz_expires - 1, ): access_logger.log(request, response, 0.0) call3 = 
mock_logger.info.call_args[0][0] assert "-0400" in call3, f"Expected EDT offset in {call3}" def test_access_logger_dicts() -> None: log_format = "%{User-Agent}i %{Content-Length}o %{None}i" mock_logger = mock.Mock() access_logger = AccessLogger(mock_logger, log_format) request = mock.Mock( headers={"User-Agent": "Mock/1.0"}, version=(1, 1), remote="127.0.0.2" ) response = mock.Mock(headers={"Content-Length": 123}) access_logger.log(request, response, 0.0) assert not mock_logger.error.called expected = "Mock/1.0 123 -" extra = { "request_header": {"User-Agent": "Mock/1.0", "None": "-"}, "response_header": {"Content-Length": 123}, } mock_logger.info.assert_called_with(expected, extra=extra) def test_access_logger_unix_socket() -> None: log_format = "|%a|" mock_logger = mock.Mock() access_logger = AccessLogger(mock_logger, log_format) request = mock.Mock(headers={"User-Agent": "Mock/1.0"}, version=(1, 1), remote="") response = mock.Mock() access_logger.log(request, response, 0.0) assert not mock_logger.error.called expected = "||" mock_logger.info.assert_called_with(expected, extra={"remote_address": ""}) def test_logger_no_message() -> None: mock_logger = mock.Mock() access_logger = AccessLogger(mock_logger, "%r %{content-type}i") extra_dict = { "first_request_line": "GET / HTTP/1.1", "request_header": {"content-type": "-"}, } access_logger.log(make_mocked_request("GET", "/"), web.Response(), 0.0) mock_logger.info.assert_called_with("GET / HTTP/1.1 -", extra=extra_dict) def test_logger_internal_error() -> None: mock_logger = mock.Mock() access_logger = AccessLogger(mock_logger, "%D") access_logger.log(make_mocked_request("GET", "/"), web.Response(), "invalid") # type: ignore[arg-type] mock_logger.exception.assert_called_with("Error in logging") def test_logger_no_transport() -> None: mock_logger = mock.Mock() access_logger = AccessLogger(mock_logger, "%a") access_logger.log(make_mocked_request("GET", "/"), web.Response(), 0.0) mock_logger.info.assert_called_with("-", 
extra={"remote_address": "-"}) def test_logger_abc() -> None: class Logger(AbstractAccessLogger): def log( self, request: web.BaseRequest, response: web.StreamResponse, time: float ) -> None: 1 / 0 mock_logger = mock.Mock() access_logger: AbstractAccessLogger = Logger(mock_logger, "") with pytest.raises(ZeroDivisionError): access_logger.log(make_mocked_request("GET", "/"), web.Response(), 0.0) class Logger2(AbstractAccessLogger): def log( self, request: web.BaseRequest, response: web.StreamResponse, time: float ) -> None: self.logger.info( self.log_format.format(request=request, response=response, time=time) ) mock_logger = mock.Mock() access_logger = Logger2(mock_logger, "{request} {response} {time}") access_logger.log("request", "response", 1) # type: ignore[arg-type] mock_logger.info.assert_called_with("request response 1") async def test_exc_info_context( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: exc_msg = None class Logger(AbstractAccessLogger): def log( self, request: web.BaseRequest, response: web.StreamResponse, time: float ) -> None: nonlocal exc_msg exc_msg = "{0.__name__}: {1}".format(*sys.exc_info()) async def handler(request: web.BaseRequest) -> NoReturn: raise RuntimeError("intentional runtime error") logger = mock.Mock() server = await aiohttp_raw_server(handler, access_log_class=Logger, logger=logger) cli = await aiohttp_client(server) resp = await cli.get("/path/to", headers={"Accept": "text/html"}) assert resp.status == 500 assert exc_msg == "RuntimeError: intentional runtime error" async def test_async_logger( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: msg = None class Logger(AbstractAsyncAccessLogger): async def log( self, request: web.BaseRequest, response: web.StreamResponse, time: float ) -> None: nonlocal msg msg = f"{request.path}: {response.status}" async def handler(request: web.BaseRequest) -> web.Response: return Response(text="ok") logger = mock.Mock() server = 
await aiohttp_raw_server(handler, access_log_class=Logger, logger=logger) cli = await aiohttp_client(server) resp = await cli.get("/path/to", headers={"Accept": "text/html"}) assert resp.status == 200 assert msg == "/path/to: 200" async def test_contextvars_logger( aiohttp_server: AiohttpServer, aiohttp_client: AiohttpClient ) -> None: VAR = ContextVar[str]("VAR") async def handler(request: web.Request) -> web.Response: return web.Response() async def middleware(request: web.Request, handler: Handler) -> web.StreamResponse: VAR.set("uuid") return await handler(request) msg = None class Logger(AbstractAccessLogger): def log( self, request: web.BaseRequest, response: web.StreamResponse, time: float ) -> None: nonlocal msg msg = f"contextvars: {VAR.get()}" app = web.Application(middlewares=[middleware]) app.router.add_get("/", handler) server = await aiohttp_server(app, access_log_class=Logger) client = await aiohttp_client(server) resp = await client.get("/") assert 200 == resp.status assert msg == "contextvars: uuid" def test_access_logger_feeds_logger(caplog: pytest.LogCaptureFixture) -> None: """Test that the logger still works.""" mock_logger = logging.getLogger("test.aiohttp.log") mock_logger.setLevel(logging.INFO) access_logger = AccessLogger(mock_logger, "%b") access_logger.log( mock.Mock(name="mock_request"), mock.Mock(name="mock_response"), 42 ) assert "mock_response" in caplog.text async def test_logger_does_not_log_when_not_enabled( aiohttp_server: AiohttpServer, aiohttp_client: AiohttpClient, caplog: pytest.LogCaptureFixture, ) -> None: """Test logger does nothing when not enabled.""" async def handler(request: web.Request) -> web.Response: return web.Response() class Logger(AbstractAccessLogger): def log( self, request: web.BaseRequest, response: web.StreamResponse, time: float ) -> None: self.logger.critical("This should not be logged") # pragma: no cover @property def enabled(self) -> bool: return False app = web.Application() app.router.add_get("/", 
handler) server = await aiohttp_server(app, access_log_class=Logger) client = await aiohttp_client(server) resp = await client.get("/") assert 200 == resp.status assert "This should not be logged" not in caplog.text async def test_logger_set_to_none( aiohttp_server: AiohttpServer, aiohttp_client: AiohttpClient, caplog: pytest.LogCaptureFixture, ) -> None: """Test logger does nothing when access_log is set to None.""" async def handler(request: web.Request) -> web.Response: return web.Response() class Logger(AbstractAccessLogger): def log( self, request: web.BaseRequest, response: web.StreamResponse, time: float ) -> None: self.logger.critical("This should not be logged") # pragma: no cover app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app, access_log=None, access_log_class=Logger) client = await aiohttp_client(server) resp = await client.get("/") assert 200 == resp.status assert "This should not be logged" not in caplog.text async def test_logger_does_not_log_when_enabled_post_init( aiohttp_server: AiohttpServer, aiohttp_client: AiohttpClient, caplog: pytest.LogCaptureFixture, ) -> None: """Test logger does nothing when not enabled even if enabled post init.""" async def handler(request: web.Request) -> web.Response: return web.Response() enabled = False class Logger(AbstractAccessLogger): def log( self, request: web.BaseRequest, response: web.StreamResponse, time: float ) -> None: self.logger.critical("This should not be logged") # pragma: no cover @property def enabled(self) -> bool: """Check if logger is enabled.""" # Avoid formatting the log line if it will not be emitted. 
return enabled app = web.Application() app.router.add_get("/", handler) server = await aiohttp_server(app, access_log_class=Logger) client = await aiohttp_client(server) resp = await client.get("/") assert 200 == resp.status assert "This should not be logged" not in caplog.text assert not server.handler.connections[0]._force_close # mock enabling logging post-init enabled = True resp = await client.get("/") assert 200 == resp.status assert "This should not be logged" not in caplog.text assert not server.handler.connections[0]._force_close ================================================ FILE: tests/test_web_middleware.py ================================================ import asyncio from collections.abc import Awaitable, Callable, Iterable from typing import NoReturn import pytest from yarl import URL from aiohttp import web, web_app from aiohttp.pytest_plugin import AiohttpClient from aiohttp.test_utils import TestClient from aiohttp.typedefs import Handler, Middleware CLI = Callable[ [Iterable[Middleware]], Awaitable[TestClient[web.Request, web.Application]] ] async def test_middleware_modifies_response( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=b"OK") async def middleware(request: web.Request, handler: Handler) -> web.Response: resp = await handler(request) assert 200 == resp.status resp.set_status(201) assert isinstance(resp, web.Response) assert resp.text is not None resp.text = resp.text + "[MIDDLEWARE]" return resp app = web.Application() app.middlewares.append(middleware) app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) # Call twice to verify cache works for _ in range(2): resp = await client.get("/") assert 201 == resp.status txt = await resp.text() assert "OK[MIDDLEWARE]" == txt async def test_middleware_handles_exception( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: async def 
handler(request: web.Request) -> NoReturn: raise RuntimeError("Error text") async def middleware(request: web.Request, handler: Handler) -> web.Response: with pytest.raises(RuntimeError) as ctx: await handler(request) return web.Response(status=501, text=str(ctx.value) + "[MIDDLEWARE]") app = web.Application() app.middlewares.append(middleware) app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) # Call twice to verify cache works for _ in range(2): resp = await client.get("/") assert 501 == resp.status txt = await resp.text() assert "Error text[MIDDLEWARE]" == txt async def test_middleware_chain( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(text="OK") handler.annotation = "annotation_value" # type: ignore[attr-defined] async def handler2(request: web.Request) -> web.Response: return web.Response(text="OK") middleware_annotation_seen_values = [] def make_middleware(num: int) -> Middleware: async def middleware(request: web.Request, handler: Handler) -> web.Response: middleware_annotation_seen_values.append( getattr(handler, "annotation", None) ) resp = await handler(request) assert isinstance(resp, web.Response) assert resp.text is not None resp.text = resp.text + f"[{num}]" return resp return middleware app = web.Application() app.middlewares.append(make_middleware(1)) app.middlewares.append(make_middleware(2)) app.router.add_route("GET", "/", handler) app.router.add_route("GET", "/r2", handler2) client = await aiohttp_client(app) resp = await client.get("/") assert 200 == resp.status txt = await resp.text() assert "OK[2][1]" == txt assert middleware_annotation_seen_values == ["annotation_value", "annotation_value"] # check that attributes from handler are not applied to handler2 resp = await client.get("/r2") assert 200 == resp.status assert middleware_annotation_seen_values == [ "annotation_value", "annotation_value", None, None, ] 
async def test_middleware_subapp( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: async def sub_handler(request: web.Request) -> web.Response: return web.Response(text="OK") sub_handler.annotation = "annotation_value" # type: ignore[attr-defined] async def handler(request: web.Request) -> web.Response: return web.Response(text="OK") middleware_annotation_seen_values = [] def make_middleware(num: int) -> Middleware: async def middleware( request: web.Request, handler: Handler ) -> web.StreamResponse: annotation = getattr(handler, "annotation", None) if annotation is not None: middleware_annotation_seen_values.append(f"{annotation}/{num}") return await handler(request) return middleware app = web.Application() app.middlewares.append(make_middleware(1)) app.router.add_route("GET", "/r2", handler) subapp = web.Application() subapp.middlewares.append(make_middleware(2)) subapp.router.add_route("GET", "/", sub_handler) app.add_subapp("/sub", subapp) client = await aiohttp_client(app) resp = await client.get("/sub/") assert 200 == resp.status await resp.text() assert middleware_annotation_seen_values == [ "annotation_value/1", "annotation_value/2", ] # check that attributes from sub_handler are not applied to handler del middleware_annotation_seen_values[:] resp = await client.get("/r2") assert 200 == resp.status assert middleware_annotation_seen_values == [] @pytest.fixture def cli(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> CLI: async def handler(request: web.Request) -> web.Response: return web.Response(text="OK") def wrapper( extra_middlewares: Iterable[Middleware], ) -> Awaitable[TestClient[web.Request, web.Application]]: app = web.Application() app.router.add_route("GET", "/resource1", handler) app.router.add_route("GET", "/resource2/", handler) app.router.add_route("GET", "/resource1/a/b", handler) app.router.add_route("GET", "/resource2/a/b/", handler) app.router.add_route("GET", "/resource2/a/b%2Fc/", handler) 
app.middlewares.extend(extra_middlewares) return aiohttp_client(app, server_kwargs={"skip_url_asserts": True}) return wrapper class TestNormalizePathMiddleware: @pytest.mark.parametrize( "path, status", [ ("/resource1", 200), ("/resource1/", 404), ("/resource2", 200), ("/resource2/", 200), ("/resource1?p1=1&p2=2", 200), ("/resource1/?p1=1&p2=2", 404), ("/resource2?p1=1&p2=2", 200), ("/resource2/?p1=1&p2=2", 200), ("/resource2/a/b%2Fc", 200), ("/resource2/a/b%2Fc/", 200), ], ) async def test_add_trailing_when_necessary( self, path: str, status: int, cli: CLI ) -> None: extra_middlewares = [web.normalize_path_middleware(merge_slashes=False)] client = await cli(extra_middlewares) resp = await client.get(path) assert resp.status == status assert resp.url.query == URL(path).query @pytest.mark.parametrize( "path, status", [ ("/resource1", 200), ("/resource1/", 200), ("/resource2", 404), ("/resource2/", 200), ("/resource1?p1=1&p2=2", 200), ("/resource1/?p1=1&p2=2", 200), ("/resource2?p1=1&p2=2", 404), ("/resource2/?p1=1&p2=2", 200), ("/resource2/a/b%2Fc", 404), ("/resource2/a/b%2Fc/", 200), ("/resource12", 404), ("/resource12345", 404), ], ) async def test_remove_trailing_when_necessary( self, path: str, status: int, cli: CLI ) -> None: extra_middlewares = [ web.normalize_path_middleware( append_slash=False, remove_slash=True, merge_slashes=False ) ] client = await cli(extra_middlewares) resp = await client.get(path) assert resp.status == status assert resp.url.query == URL(path).query @pytest.mark.parametrize( "path, status", [ ("/resource1", 200), ("/resource1/", 404), ("/resource2", 404), ("/resource2/", 200), ("/resource1?p1=1&p2=2", 200), ("/resource1/?p1=1&p2=2", 404), ("/resource2?p1=1&p2=2", 404), ("/resource2/?p1=1&p2=2", 200), ("/resource2/a/b%2Fc", 404), ("/resource2/a/b%2Fc/", 200), ], ) async def test_no_trailing_slash_when_disabled( self, path: str, status: int, cli: CLI ) -> None: extra_middlewares = [ web.normalize_path_middleware(append_slash=False, 
merge_slashes=False) ] client = await cli(extra_middlewares) resp = await client.get(path) assert resp.status == status assert resp.url.query == URL(path).query @pytest.mark.parametrize( "path, status", [ ("/resource1/a/b", 200), ("//resource1//a//b", 200), ("//resource1//a//b/", 404), ("///resource1//a//b", 200), ("/////resource1/a///b", 200), ("/////resource1/a//b/", 404), ("/resource1/a/b?p=1", 200), ("//resource1//a//b?p=1", 200), ("//resource1//a//b/?p=1", 404), ("///resource1//a//b?p=1", 200), ("/////resource1/a///b?p=1", 200), ("/////resource1/a//b/?p=1", 404), ], ) async def test_merge_slash(self, path: str, status: int, cli: CLI) -> None: extra_middlewares = [web.normalize_path_middleware(append_slash=False)] client = await cli(extra_middlewares) resp = await client.get(path) assert resp.status == status assert resp.url.query == URL(path).query @pytest.mark.parametrize( "path, status", [ ("/resource1/a/b", 200), ("/resource1/a/b/", 404), ("//resource2//a//b", 200), ("//resource2//a//b/", 200), ("///resource1//a//b", 200), ("///resource1//a//b/", 404), ("/////resource1/a///b", 200), ("/////resource1/a///b/", 404), ("/resource2/a/b", 200), ("//resource2//a//b", 200), ("//resource2//a//b/", 200), ("///resource2//a//b", 200), ("///resource2//a//b/", 200), ("/////resource2/a///b", 200), ("/////resource2/a///b/", 200), ("/resource1/a/b?p=1", 200), ("/resource1/a/b/?p=1", 404), ("//resource2//a//b?p=1", 200), ("//resource2//a//b/?p=1", 200), ("///resource1//a//b?p=1", 200), ("///resource1//a//b/?p=1", 404), ("/////resource1/a///b?p=1", 200), ("/////resource1/a///b/?p=1", 404), ("/resource2/a/b?p=1", 200), ("//resource2//a//b?p=1", 200), ("//resource2//a//b/?p=1", 200), ("///resource2//a//b?p=1", 200), ("///resource2//a//b/?p=1", 200), ("/////resource2/a///b?p=1", 200), ("/////resource2/a///b/?p=1", 200), ], ) async def test_append_and_merge_slash( self, path: str, status: int, cli: CLI ) -> None: extra_middlewares = [web.normalize_path_middleware()] client = 
await cli(extra_middlewares) resp = await client.get(path) assert resp.status == status assert resp.url.query == URL(path).query @pytest.mark.parametrize( "path, status", [ ("/resource1/a/b", 200), ("/resource1/a/b/", 200), ("//resource2//a//b", 404), ("//resource2//a//b/", 200), ("///resource1//a//b", 200), ("///resource1//a//b/", 200), ("/////resource1/a///b", 200), ("/////resource1/a///b/", 200), ("/////resource1/a///b///", 200), ("/resource2/a/b", 404), ("//resource2//a//b", 404), ("//resource2//a//b/", 200), ("///resource2//a//b", 404), ("///resource2//a//b/", 200), ("/////resource2/a///b", 404), ("/////resource2/a///b/", 200), ("/resource1/a/b?p=1", 200), ("/resource1/a/b/?p=1", 200), ("//resource2//a//b?p=1", 404), ("//resource2//a//b/?p=1", 200), ("///resource1//a//b?p=1", 200), ("///resource1//a//b/?p=1", 200), ("/////resource1/a///b?p=1", 200), ("/////resource1/a///b/?p=1", 200), ("/resource2/a/b?p=1", 404), ("//resource2//a//b?p=1", 404), ("//resource2//a//b/?p=1", 200), ("///resource2//a//b?p=1", 404), ("///resource2//a//b/?p=1", 200), ("/////resource2/a///b?p=1", 404), ("/////resource2/a///b/?p=1", 200), ], ) async def test_remove_and_merge_slash( self, path: str, status: int, cli: CLI ) -> None: extra_middlewares = [ web.normalize_path_middleware(append_slash=False, remove_slash=True) ] client = await cli(extra_middlewares) resp = await client.get(path) assert resp.status == status assert resp.url.query == URL(path).query async def test_cannot_remove_and_add_slash(self) -> None: with pytest.raises(AssertionError): web.normalize_path_middleware(append_slash=True, remove_slash=True) @pytest.mark.parametrize( ["append_slash", "remove_slash"], [ (True, False), (False, True), (False, False), ], ) async def test_open_redirects( self, append_slash: bool, remove_slash: bool, aiohttp_client: AiohttpClient ) -> None: async def handle(request: web.Request) -> web.StreamResponse: pytest.fail( "Security advisory 'GHSA-v6wp-4m6f-gcjg' test handler " "matched 
unexpectedly", pytrace=False, ) app = web.Application( middlewares=[ web.normalize_path_middleware( append_slash=append_slash, remove_slash=remove_slash ) ] ) app.add_routes([web.get("/", handle), web.get("/google.com", handle)]) client = await aiohttp_client(app, server_kwargs={"skip_url_asserts": True}) resp = await client.get("//google.com", allow_redirects=False) assert resp.status == 308 assert resp.headers["Location"] == "/google.com" assert resp.url.query == URL("//google.com").query async def test_bug_3669(aiohttp_client: AiohttpClient) -> None: async def paymethod(request: web.Request) -> NoReturn: assert False app = web.Application() app.router.add_route("GET", "/paymethod", paymethod) app.middlewares.append( web.normalize_path_middleware(append_slash=False, remove_slash=True) ) client = await aiohttp_client(app, server_kwargs={"skip_url_asserts": True}) resp = await client.get("/paymethods") assert resp.status == 404 assert resp.url.path != "/paymethod" async def test_old_style_middleware( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: async def view_handler(request: web.Request) -> web.Response: return web.Response(body=b"OK") with pytest.deprecated_call( match=r"^Middleware decorator is deprecated since 4\.0 and its " r"behaviour is default, you can simply remove this decorator\.$", ): @web.middleware async def middleware(request: web.Request, handler: Handler) -> web.Response: resp = await handler(request) assert 200 == resp.status resp.set_status(201) assert isinstance(resp, web.Response) assert resp.text is not None resp.text = resp.text + "[old style middleware]" return resp app = web.Application(middlewares=[middleware]) app.router.add_route("GET", "/", view_handler) client = await aiohttp_client(app) resp = await client.get("/") assert 201 == resp.status txt = await resp.text() assert "OK[old style middleware]" == txt async def test_new_style_middleware_class( loop: asyncio.AbstractEventLoop, aiohttp_client: 
AiohttpClient ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=b"OK") class Middleware: async def __call__( self, request: web.Request, handler: Handler ) -> web.Response: resp = await handler(request) assert 200 == resp.status resp.set_status(201) assert isinstance(resp, web.Response) assert resp.text is not None resp.text = resp.text + "[new style middleware]" return resp app = web.Application() app.middlewares.append(Middleware()) app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 201 == resp.status txt = await resp.text() assert "OK[new style middleware]" == txt async def test_new_style_middleware_method( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: async def handler(request: web.Request) -> web.Response: return web.Response(body=b"OK") class Middleware: async def call(self, request: web.Request, handler: Handler) -> web.Response: resp = await handler(request) assert 200 == resp.status resp.set_status(201) assert isinstance(resp, web.Response) assert resp.text is not None resp.text = resp.text + "[new style middleware]" return resp app = web.Application() app.middlewares.append(Middleware().call) app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert 201 == resp.status txt = await resp.text() assert "OK[new style middleware]" == txt async def test_middleware_does_not_leak(aiohttp_client: AiohttpClient) -> None: async def any_handler(request: web.Request) -> NoReturn: assert False class Middleware: async def call( self, request: web.Request, handler: Handler ) -> web.StreamResponse: return await handler(request) app = web.Application() app.router.add_route("POST", "/any", any_handler) app.middlewares.append(Middleware().call) client = await aiohttp_client(app) web_app._cached_build_middleware.cache_clear() for _ in range(10): resp = await client.get("/any") assert 
resp.status == 405 assert web_app._cached_build_middleware.cache_info().currsize < 10 ================================================ FILE: tests/test_web_protocol.py ================================================ import asyncio from typing import Any, cast from unittest import mock from aiohttp.web_protocol import RequestHandler class _DummyManager: def __init__(self) -> None: self.request_handler = mock.Mock() self.request_factory = mock.Mock() class _DummyParser: def __init__(self) -> None: self.received: list[bytes] = [] def feed_data(self, data: bytes) -> tuple[bool, bytes]: self.received.append(data) return False, b"" def test_set_parser_does_not_call_data_received_cb_for_tail( loop: asyncio.AbstractEventLoop, ) -> None: handler: RequestHandler[Any] = RequestHandler(cast(Any, _DummyManager()), loop=loop) handler._message_tail = b"tail" cb = mock.Mock() parser = _DummyParser() handler.set_parser(parser, data_received_cb=cb) cb.assert_not_called() assert parser.received == [b"tail"] def test_data_received_calls_data_received_cb( loop: asyncio.AbstractEventLoop, ) -> None: handler: RequestHandler[Any] = RequestHandler(cast(Any, _DummyManager()), loop=loop) cb = mock.Mock() parser = _DummyParser() handler.set_parser(parser, data_received_cb=cb) handler.data_received(b"x") assert cb.call_count == 1 assert parser.received == [b"x"] ================================================ FILE: tests/test_web_request.py ================================================ import asyncio import datetime import logging import socket import ssl import sys import time import weakref from collections.abc import Iterator, MutableMapping from typing import NoReturn from unittest import mock import pytest from multidict import CIMultiDict, CIMultiDictProxy, MultiDict from yarl import URL from aiohttp import ETag, HttpVersion, web from aiohttp.base_protocol import BaseProtocol from aiohttp.http_exceptions import BadHttpMessage, LineTooLong from aiohttp.http_parser import 
RawRequestMessage
from aiohttp.pytest_plugin import AiohttpClient
from aiohttp.streams import StreamReader
from aiohttp.test_utils import make_mocked_request
from aiohttp.web_request import _FORWARDED_PAIR_RE


@pytest.fixture
def protocol() -> mock.Mock:
    """Return a mock protocol whose ``_reading_paused`` is False."""
    return mock.Mock(_reading_paused=False)


def test_base_ctor() -> None:
    """Build a BaseRequest from a raw message and check its basic properties."""
    message = RawRequestMessage(
        "GET",
        "/path/to?a=1&b=2",
        HttpVersion(1, 1),
        CIMultiDictProxy(CIMultiDict()),
        (),
        False,
        None,
        False,
        False,
        URL("/path/to?a=1&b=2"),
    )

    req = web.BaseRequest(
        message, mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock()
    )

    assert "GET" == req.method
    assert HttpVersion(1, 1) == req.version
    # MacOS may return CamelCased host name, need .lower()
    # FQDN can be wider than host, e.g.
    # 'fv-az397-495' in 'fv-az397-495.internal.cloudapp.net'
    assert req.host.lower() in socket.getfqdn().lower()
    assert "/path/to?a=1&b=2" == req.path_qs
    assert "/path/to" == req.path
    assert "a=1&b=2" == req.query_string
    assert CIMultiDict() == req.headers
    assert () == req.raw_headers

    get = req.query
    assert MultiDict([("a", "1"), ("b", "2")]) == get
    # second call should return the same object
    assert get is req.query

    assert req.keep_alive
    assert req


def test_ctor() -> None:
    """Check the properties of a request built by ``make_mocked_request``."""
    req = make_mocked_request("GET", "/path/to?a=1&b=2")

    assert "GET" == req.method
    assert HttpVersion(1, 1) == req.version
    # MacOS may return CamelCased host name, need .lower()
    # FQDN can be wider than host, e.g.
# 'fv-az397-495' in 'fv-az397-495.internal.cloudapp.net' assert req.host.lower() in socket.getfqdn().lower() assert "/path/to?a=1&b=2" == req.path_qs assert "/path/to" == req.path assert "a=1&b=2" == req.query_string assert CIMultiDict() == req.headers assert () == req.raw_headers get = req.query assert MultiDict([("a", "1"), ("b", "2")]) == get # second call should return the same object assert get is req.query assert req.keep_alive # just make sure that all lines of make_mocked_request covered headers = CIMultiDict(FOO="bar") payload = mock.Mock() protocol = mock.Mock() app = mock.Mock() req = make_mocked_request( "GET", "/path/to?a=1&b=2", headers=headers, protocol=protocol, payload=payload, app=app, ) assert req.app is app assert req.content is payload assert req.protocol is protocol assert req.transport is protocol.transport assert req.headers == headers assert req.raw_headers == ((b"FOO", b"bar"),) assert req.task is req._task def test_doubleslashes() -> None: # NB: //foo/bar is an absolute URL with foo netloc and /bar path req = make_mocked_request("GET", "/bar//foo/") assert "/bar//foo/" == req.path def test_content_type_not_specified() -> None: req = make_mocked_request("Get", "/") assert "application/octet-stream" == req.content_type def test_content_type_from_spec() -> None: req = make_mocked_request( "Get", "/", CIMultiDict([("CONTENT-TYPE", "application/json")]) ) assert "application/json" == req.content_type def test_content_type_from_spec_with_charset() -> None: req = make_mocked_request( "Get", "/", CIMultiDict([("CONTENT-TYPE", "text/html; charset=UTF-8")]) ) assert "text/html" == req.content_type assert "UTF-8" == req.charset def test_calc_content_type_on_getting_charset() -> None: req = make_mocked_request( "Get", "/", CIMultiDict([("CONTENT-TYPE", "text/html; charset=UTF-8")]) ) assert "UTF-8" == req.charset assert "text/html" == req.content_type def test_urlencoded_querystring() -> None: req = make_mocked_request("GET", 
"/yandsearch?text=%D1%82%D0%B5%D0%BA%D1%81%D1%82") assert {"text": "текст"} == req.query def test_non_ascii_path() -> None: req = make_mocked_request("GET", "/путь") assert "/путь" == req.path def test_non_ascii_raw_path() -> None: req = make_mocked_request("GET", "/путь") assert "/путь" == req.raw_path def test_absolute_url() -> None: req = make_mocked_request("GET", "https://example.com/path/to?a=1") assert req.url == URL("https://example.com/path/to?a=1") assert req.scheme == "https" assert req.host == "example.com" assert req.rel_url == URL.build(path="/path/to", query={"a": "1"}) def test_clone_absolute_scheme() -> None: req = make_mocked_request("GET", "https://example.com/path/to?a=1") assert req.scheme == "https" req2 = req.clone(scheme="http") assert req2.scheme == "http" assert req2.url.scheme == "http" def test_clone_absolute_host() -> None: req = make_mocked_request("GET", "https://example.com/path/to?a=1") assert req.host == "example.com" req2 = req.clone(host="foo.test") assert req2.host == "foo.test" assert req2.url.host == "foo.test" def test_content_length() -> None: req = make_mocked_request("Get", "/", CIMultiDict([("CONTENT-LENGTH", "123")])) assert 123 == req.content_length def test_range_to_slice_head() -> None: req = make_mocked_request( "GET", "/", headers=CIMultiDict([("RANGE", "bytes=0-499")]) ) assert isinstance(req.http_range, slice) assert req.http_range.start == 0 and req.http_range.stop == 500 def test_range_to_slice_mid() -> None: req = make_mocked_request( "GET", "/", headers=CIMultiDict([("RANGE", "bytes=500-999")]) ) assert isinstance(req.http_range, slice) assert req.http_range.start == 500 and req.http_range.stop == 1000 def test_range_to_slice_tail_start() -> None: req = make_mocked_request( "GET", "/", headers=CIMultiDict([("RANGE", "bytes=9500-")]) ) assert isinstance(req.http_range, slice) assert req.http_range.start == 9500 and req.http_range.stop is None def test_range_to_slice_tail_stop() -> None: req = 
make_mocked_request( "GET", "/", headers=CIMultiDict([("RANGE", "bytes=-500")]) ) assert isinstance(req.http_range, slice) assert req.http_range.start == -500 and req.http_range.stop is None def test_range_non_ascii() -> None: # ५ = DEVANAGARI DIGIT FIVE req = make_mocked_request("GET", "/", headers=CIMultiDict([("RANGE", "bytes=4-५")])) with pytest.raises(ValueError, match="range not in acceptable format"): req.http_range def test_non_keepalive_on_http10() -> None: req = make_mocked_request("GET", "/", version=HttpVersion(1, 0)) assert not req.keep_alive def test_non_keepalive_on_closing() -> None: req = make_mocked_request("GET", "/", closing=True) assert not req.keep_alive async def test_call_POST_on_GET_request() -> None: req = make_mocked_request("GET", "/") ret = await req.post() assert CIMultiDict() == ret async def test_call_POST_on_weird_content_type() -> None: req = make_mocked_request( "POST", "/", headers=CIMultiDict({"CONTENT-TYPE": "something/weird"}) ) ret = await req.post() assert CIMultiDict() == ret async def test_call_POST_twice() -> None: req = make_mocked_request("GET", "/") ret1 = await req.post() ret2 = await req.post() assert ret1 is ret2 def test_no_request_cookies() -> None: req = make_mocked_request("GET", "/") assert req.cookies == {} cookies = req.cookies assert cookies is req.cookies def test_request_cookie() -> None: headers = CIMultiDict(COOKIE="cookie1=value1; cookie2=value2") req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == {"cookie1": "value1", "cookie2": "value2"} def test_request_cookie__set_item() -> None: headers = CIMultiDict(COOKIE="name=value") req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == {"name": "value"} with pytest.raises(TypeError): req.cookies["my"] = "value" # type: ignore[index] def test_request_cookies_with_special_characters() -> None: """Test that cookies with special characters in names are accepted. 
This tests the fix for issue #2683 where cookies with special characters like {, }, / in their names would cause a 500 error. The fix makes the cookie parser more tolerant to handle real-world cookies. """ # Test cookie names with curly braces (e.g., ISAWPLB{DB45DF86-F806-407C-932C-D52A60E4019E}) headers = CIMultiDict(COOKIE="{test}=value1; normal=value2") req = make_mocked_request("GET", "/", headers=headers) # Both cookies should be parsed successfully assert req.cookies == {"{test}": "value1", "normal": "value2"} # Test cookie names with forward slash headers = CIMultiDict(COOKIE="test/name=value1; valid=value2") req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == {"test/name": "value1", "valid": "value2"} # Test cookie names with various special characters headers = CIMultiDict( COOKIE="test{foo}bar=value1; test/path=value2; normal_cookie=value3" ) req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == { "test{foo}bar": "value1", "test/path": "value2", "normal_cookie": "value3", } def test_request_cookies_real_world_examples() -> None: """Test handling of real-world cookie examples from issue #2683.""" # Example from the issue: ISAWPLB{DB45DF86-F806-407C-932C-D52A60E4019E} headers = CIMultiDict( COOKIE="ISAWPLB{DB45DF86-F806-407C-932C-D52A60E4019E}=val1; normal_cookie=val2" ) req = make_mocked_request("GET", "/", headers=headers) # All cookies should be parsed successfully assert req.cookies == { "ISAWPLB{DB45DF86-F806-407C-932C-D52A60E4019E}": "val1", "normal_cookie": "val2", } # Multiple cookies with special characters headers = CIMultiDict( COOKIE="{cookie1}=val1; cookie/2=val2; cookie[3]=val3; cookie(4)=val4" ) req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == { "{cookie1}": "val1", "cookie/2": "val2", "cookie[3]": "val3", "cookie(4)": "val4", } def test_request_cookies_edge_cases() -> None: """Test edge cases for cookie parsing.""" # Empty cookie value headers = 
CIMultiDict(COOKIE="test=; normal=value") req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == {"test": "", "normal": "value"} # Cookie with quoted value headers = CIMultiDict(COOKIE='test="quoted value"; normal=unquoted') req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == {"test": "quoted value", "normal": "unquoted"} def test_request_cookies_many_invalid(caplog: pytest.LogCaptureFixture) -> None: """Test many invalid cookies doesn't cause too many logs.""" bad = "bad" + chr(1) + "name" cookie = "; ".join(f"{bad}{i}=1" for i in range(3000)) req = make_mocked_request("GET", "/", headers=CIMultiDict(COOKIE=cookie)) with caplog.at_level(logging.DEBUG): cookies = req.cookies assert len(caplog.record_tuples) == 1 _, level, msg = caplog.record_tuples[0] assert level is logging.DEBUG assert "Cannot load cookie" in msg assert cookies == {} def test_request_cookies_no_500_error() -> None: """Test that cookies with special characters don't cause 500 errors. This specifically tests that issue #2683 is fixed - previously cookies with characters like { } would cause CookieError and 500 responses. """ # This cookie format previously caused 500 errors headers = CIMultiDict(COOKIE="ISAWPLB{DB45DF86-F806-407C-932C-D52A60E4019E}=test") # Should not raise any exception when accessing cookies req = make_mocked_request("GET", "/", headers=headers) cookies = req.cookies # This used to raise CookieError # Verify the cookie was parsed successfully assert "ISAWPLB{DB45DF86-F806-407C-932C-D52A60E4019E}" in cookies assert cookies["ISAWPLB{DB45DF86-F806-407C-932C-D52A60E4019E}"] == "test" def test_request_cookies_quoted_values() -> None: """Test that quoted cookie values are handled consistently. This tests the fix for issue #5397 where quoted cookie values were handled inconsistently based on whether domain attributes were present. The new parser should always unquote cookie values consistently. 
""" # Test simple quoted cookie value headers = CIMultiDict(COOKIE='sess="quoted_value"') req = make_mocked_request("GET", "/", headers=headers) # Quotes should be removed consistently assert req.cookies == {"sess": "quoted_value"} # Test quoted cookie with semicolon in value headers = CIMultiDict(COOKIE='data="value;with;semicolons"') req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == {"data": "value;with;semicolons"} # Test mixed quoted and unquoted cookies headers = CIMultiDict( COOKIE='quoted="value1"; unquoted=value2; also_quoted="value3"' ) req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == { "quoted": "value1", "unquoted": "value2", "also_quoted": "value3", } # Test escaped quotes in cookie value headers = CIMultiDict(COOKIE=r'escaped="value with \" quote"') req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == {"escaped": 'value with " quote'} # Test empty quoted value headers = CIMultiDict(COOKIE='empty=""') req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == {"empty": ""} def test_request_cookies_with_attributes() -> None: """Test that cookie attributes are parsed as cookies per RFC 6265. Per RFC 6265 Section 5.4, Cookie headers contain only name-value pairs. Names that match attribute names (Domain, Path, etc.) should be treated as regular cookies, not as attributes. 
""" # Cookie with domain - both should be parsed as cookies headers = CIMultiDict(COOKIE='sess="quoted_value"; Domain=.example.com') req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == {"sess": "quoted_value", "Domain": ".example.com"} # Cookie with multiple attribute names - all parsed as cookies headers = CIMultiDict(COOKIE='token="abc123"; Path=/; Secure; HttpOnly') req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == {"token": "abc123", "Path": "/", "Secure": "", "HttpOnly": ""} # Multiple cookies with attribute names mixed in headers = CIMultiDict( COOKIE='c1="v1"; Domain=.example.com; c2="v2"; Path=/api; c3=v3; Secure' ) req = make_mocked_request("GET", "/", headers=headers) assert req.cookies == { "c1": "v1", "Domain": ".example.com", "c2": "v2", "Path": "/api", "c3": "v3", "Secure": "", } def test_match_info() -> None: req = make_mocked_request("GET", "/") assert req._match_info is req.match_info def test_request_is_mutable_mapping() -> None: req = make_mocked_request("GET", "/") assert isinstance(req, MutableMapping) assert req # even when the MutableMapping is empty, request should always be True req["key"] = "value" assert "value" == req["key"] def test_request_delitem() -> None: req = make_mocked_request("GET", "/") req["key"] = "value" assert "value" == req["key"] del req["key"] assert "key" not in req def test_request_len() -> None: req = make_mocked_request("GET", "/") assert len(req) == 0 req["key"] = "value" assert len(req) == 1 def test_request_iter() -> None: req = make_mocked_request("GET", "/") req["key"] = "value" req["key2"] = "value2" key3 = web.RequestKey("key3", str) req[key3] = "value3" assert set(req) == {"key", "key2", key3} def test_requestkey() -> None: req = make_mocked_request("GET", "/") key = web.RequestKey("key", str) req[key] = "value" assert req[key] == "value" assert len(req) == 1 del req[key] assert len(req) == 0 def test_request_get_requestkey() -> None: req = 
make_mocked_request("GET", "/") key = web.RequestKey("key", int) assert req.get(key, "foo") == "foo" req[key] = 5 assert req.get(key, "foo") == 5 def test_requestkey_repr_concrete() -> None: key = web.RequestKey("key", int) assert repr(key) in ( "", # pytest-xdist "", ) key2 = web.RequestKey("key", web.Request) assert repr(key2) in ( # pytest-xdist: "", "", ) def test_requestkey_repr_nonconcrete() -> None: key = web.RequestKey("key", Iterator[int]) if sys.version_info < (3, 11): assert repr(key) in ( # pytest-xdist: "", "", ) else: assert repr(key) in ( # pytest-xdist: "", "", ) def test_requestkey_repr_annotated() -> None: key = web.RequestKey[Iterator[int]]("key") if sys.version_info < (3, 11): assert repr(key) in ( # pytest-xdist: "", "", ) else: assert repr(key) in ( # pytest-xdist: "", "", ) def test___repr__() -> None: req = make_mocked_request("GET", "/path/to") assert "" == repr(req) def test___repr___non_ascii_path() -> None: req = make_mocked_request("GET", "/path/\U0001f415\U0001f308") assert "" == repr(req) def test_http_scheme() -> None: req = make_mocked_request("GET", "/", headers={"Host": "example.com"}) assert "http" == req.scheme assert req.secure is False def test_https_scheme_by_ssl_transport() -> None: context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) req = make_mocked_request( "GET", "/", headers={"Host": "example.com"}, sslcontext=context ) assert "https" == req.scheme assert req.secure is True def test_single_forwarded_header() -> None: header = "by=identifier;for=identifier;host=identifier;proto=identifier" req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert req.forwarded[0]["by"] == "identifier" assert req.forwarded[0]["for"] == "identifier" assert req.forwarded[0]["host"] == "identifier" assert req.forwarded[0]["proto"] == "identifier" def test_forwarded_re_performance() -> None: FORWARDED_RE_TIME_THRESHOLD_SECONDS = 0.08 value = "{" + "f" * 54773 + "z\x00a=v" start = time.perf_counter() match = 
_FORWARDED_PAIR_RE.match(value) elapsed = time.perf_counter() - start # If this is taking more time, there's probably a performance/ReDoS issue. assert elapsed < FORWARDED_RE_TIME_THRESHOLD_SECONDS, ( f"Regex took {elapsed * 1000:.1f}ms, " f"expected <{FORWARDED_RE_TIME_THRESHOLD_SECONDS * 1000:.0f}ms - potential ReDoS issue" ) # This example shouldn't produce a match either. assert match is None @pytest.mark.parametrize( "forward_for_in, forward_for_out", [ ("1.2.3.4:1234", "1.2.3.4:1234"), ("1.2.3.4", "1.2.3.4"), ('"[2001:db8:cafe::17]:1234"', "[2001:db8:cafe::17]:1234"), ('"[2001:db8:cafe::17]"', "[2001:db8:cafe::17]"), ], ) def test_forwarded_node_identifier(forward_for_in: str, forward_for_out: str) -> None: header = f"for={forward_for_in}" req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert req.forwarded == ({"for": forward_for_out},) def test_single_forwarded_header_camelcase() -> None: header = "bY=identifier;fOr=identifier;HOst=identifier;pRoTO=identifier" req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert req.forwarded[0]["by"] == "identifier" assert req.forwarded[0]["for"] == "identifier" assert req.forwarded[0]["host"] == "identifier" assert req.forwarded[0]["proto"] == "identifier" def test_single_forwarded_header_single_param() -> None: header = "BY=identifier" req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert req.forwarded[0]["by"] == "identifier" def test_single_forwarded_header_multiple_param() -> None: header = "By=identifier1,BY=identifier2, By=identifier3 , BY=identifier4" req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert len(req.forwarded) == 4 assert req.forwarded[0]["by"] == "identifier1" assert req.forwarded[1]["by"] == "identifier2" assert req.forwarded[2]["by"] == "identifier3" assert req.forwarded[3]["by"] == "identifier4" def test_single_forwarded_header_quoted_escaped() -> None: header 
= r'BY=identifier;pROTO="\lala lan\d\~ 123\!&"' req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert req.forwarded[0]["by"] == "identifier" assert req.forwarded[0]["proto"] == "lala land~ 123!&" def test_single_forwarded_header_custom_param() -> None: header = r'BY=identifier;PROTO=https;SOME="other, \"value\""' req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert len(req.forwarded) == 1 assert req.forwarded[0]["by"] == "identifier" assert req.forwarded[0]["proto"] == "https" assert req.forwarded[0]["some"] == 'other, "value"' def test_single_forwarded_header_empty_params() -> None: # This is allowed by the grammar given in RFC 7239 header = ";For=identifier;;PROTO=https;;;" req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert req.forwarded[0]["for"] == "identifier" assert req.forwarded[0]["proto"] == "https" def test_single_forwarded_header_bad_separator() -> None: header = "BY=identifier PROTO=https" req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert "proto" not in req.forwarded[0] def test_single_forwarded_header_injection1() -> None: # We might receive a header like this if we're sitting behind a reverse # proxy that blindly appends a forwarded-element without checking # the syntax of existing field-values. We should be able to recover # the appended element anyway. 
header = 'for=_injected;by=", for=_real' req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert len(req.forwarded) == 2 assert "by" not in req.forwarded[0] assert req.forwarded[1]["for"] == "_real" def test_single_forwarded_header_injection2() -> None: header = "very bad syntax, for=_real" req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert len(req.forwarded) == 2 assert "for" not in req.forwarded[0] assert req.forwarded[1]["for"] == "_real" def test_single_forwarded_header_long_quoted_string() -> None: header = 'for="' + "\\\\" * 5000 + '"' req = make_mocked_request("GET", "/", headers=CIMultiDict({"Forwarded": header})) assert req.forwarded[0]["for"] == "\\" * 5000 def test_multiple_forwarded_headers() -> None: headers = CIMultiDict[str]() headers.add("Forwarded", "By=identifier1;for=identifier2, BY=identifier3") headers.add("Forwarded", "By=identifier4;fOr=identifier5") req = make_mocked_request("GET", "/", headers=headers) assert len(req.forwarded) == 3 assert req.forwarded[0]["by"] == "identifier1" assert req.forwarded[0]["for"] == "identifier2" assert req.forwarded[1]["by"] == "identifier3" assert req.forwarded[2]["by"] == "identifier4" assert req.forwarded[2]["for"] == "identifier5" def test_multiple_forwarded_headers_bad_syntax() -> None: headers = CIMultiDict[str]() headers.add("Forwarded", "for=_1;by=_2") headers.add("Forwarded", "invalid value") headers.add("Forwarded", "") headers.add("Forwarded", "for=_3;by=_4") req = make_mocked_request("GET", "/", headers=headers) assert len(req.forwarded) == 4 assert req.forwarded[0]["for"] == "_1" assert "for" not in req.forwarded[1] assert "for" not in req.forwarded[2] assert req.forwarded[3]["by"] == "_4" def test_multiple_forwarded_headers_injection() -> None: headers = CIMultiDict[str]() # This could be sent by an attacker, hoping to "shadow" the second header. 
headers.add("Forwarded", 'for=_injected;by="') # This is added by our trusted reverse proxy. headers.add("Forwarded", "for=_real;by=_actual_proxy") req = make_mocked_request("GET", "/", headers=headers) assert len(req.forwarded) == 2 assert "by" not in req.forwarded[0] assert req.forwarded[1]["for"] == "_real" assert req.forwarded[1]["by"] == "_actual_proxy" def test_host_by_host_header() -> None: req = make_mocked_request("GET", "/", headers=CIMultiDict({"Host": "example.com"})) assert req.host == "example.com" def test_raw_headers() -> None: req = make_mocked_request("GET", "/", headers=CIMultiDict({"X-HEADER": "aaa"})) assert req.raw_headers == ((b"X-HEADER", b"aaa"),) def test_rel_url() -> None: req = make_mocked_request("GET", "/path") assert URL("/path") == req.rel_url def test_url_url() -> None: req = make_mocked_request("GET", "/path", headers={"HOST": "example.com"}) assert URL("http://example.com/path") == req.url def test_url_non_default_port() -> None: req = make_mocked_request("GET", "/path", headers={"HOST": "example.com:8123"}) assert req.url == URL("http://example.com:8123/path") def test_url_ipv6() -> None: req = make_mocked_request("GET", "/path", headers={"HOST": "[::1]:8123"}) assert req.url == URL("http://[::1]:8123/path") def test_clone() -> None: req = make_mocked_request("GET", "/path") req2 = req.clone() assert req2.method == "GET" assert req2.rel_url == URL("/path") def test_clone_client_max_size() -> None: req = make_mocked_request("GET", "/path", client_max_size=1024) req2 = req.clone() assert req._client_max_size == req2._client_max_size assert req2._client_max_size == 1024 def test_clone_override_client_max_size() -> None: req = make_mocked_request("GET", "/path", client_max_size=1024) req2 = req.clone(client_max_size=2048) assert req2.client_max_size == 2048 def test_clone_method() -> None: req = make_mocked_request("GET", "/path") req2 = req.clone(method="POST") assert req2.method == "POST" assert req2.rel_url == URL("/path") def 
test_clone_rel_url() -> None: req = make_mocked_request("GET", "/path") req2 = req.clone(rel_url=URL("/path2")) assert req2.rel_url == URL("/path2") def test_clone_rel_url_str() -> None: req = make_mocked_request("GET", "/path") req2 = req.clone(rel_url="/path2") assert req2.rel_url == URL("/path2") def test_clone_headers() -> None: req = make_mocked_request("GET", "/path", headers={"A": "B"}) req2 = req.clone(headers=CIMultiDict({"B": "C"})) assert req2.headers == CIMultiDict({"B": "C"}) assert req2.raw_headers == ((b"B", b"C"),) def test_clone_headers_dict() -> None: req = make_mocked_request("GET", "/path", headers={"A": "B"}) req2 = req.clone(headers={"B": "C"}) assert req2.headers == CIMultiDict({"B": "C"}) assert req2.raw_headers == ((b"B", b"C"),) async def test_cannot_clone_after_read(protocol: BaseProtocol) -> None: payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) payload.feed_data(b"data") payload.feed_eof() req = make_mocked_request("GET", "/path", payload=payload) await req.read() with pytest.raises(RuntimeError): req.clone() async def test_make_too_big_request(protocol: BaseProtocol) -> None: payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) large_file = 1024**2 * b"x" too_large_file = large_file + b"x" payload.feed_data(too_large_file) payload.feed_eof() req = make_mocked_request("POST", "/", payload=payload) with pytest.raises(web.HTTPRequestEntityTooLarge) as err: await req.read() assert err.value.status_code == 413 async def test_request_with_wrong_content_type_encoding(protocol: BaseProtocol) -> None: payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) payload.feed_data(b"{}") payload.feed_eof() headers = {"Content-Type": "text/html; charset=test"} req = make_mocked_request("POST", "/", payload=payload, headers=headers) with pytest.raises(web.HTTPUnsupportedMediaType) as err: await req.text() assert err.value.status_code == 415 async def 
test_make_too_big_request_same_size_to_max(protocol: BaseProtocol) -> None: payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) large_file = 1024**2 * b"x" payload.feed_data(large_file) payload.feed_eof() req = make_mocked_request("POST", "/", payload=payload) resp_text = await req.read() assert resp_text == large_file async def test_make_too_big_request_adjust_limit(protocol: BaseProtocol) -> None: payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) large_file = 1024**2 * b"x" too_large_file = large_file + b"x" payload.feed_data(too_large_file) payload.feed_eof() max_size = 1024**2 + 2 req = make_mocked_request("POST", "/", payload=payload, client_max_size=max_size) txt = await req.read() assert len(txt) == 1024**2 + 1 async def test_multipart_formdata(protocol: BaseProtocol) -> None: payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) payload.feed_data( b"-----------------------------326931944431359\r\n" b'Content-Disposition: form-data; name="a"\r\n' b"\r\n" b"b\r\n" b"-----------------------------326931944431359\r\n" b'Content-Disposition: form-data; name="c"\r\n' b"\r\n" b"d\r\n" b"-----------------------------326931944431359--\r\n" ) content_type = ( "multipart/form-data; boundary=---------------------------326931944431359" ) payload.feed_eof() req = make_mocked_request( "POST", "/", headers={"CONTENT-TYPE": content_type}, payload=payload ) result = await req.post() assert dict(result) == {"a": "b", "c": "d"} async def test_multipart_formdata_field_missing_name(protocol: BaseProtocol) -> None: # Ensure ValueError is raised when Content-Disposition has no name payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) payload.feed_data( b"-----------------------------326931944431359\r\n" b"Content-Disposition: form-data\r\n" # Missing name! 
b"\r\n" b"value\r\n" b"-----------------------------326931944431359--\r\n" ) content_type = ( "multipart/form-data; boundary=---------------------------326931944431359" ) payload.feed_eof() req = make_mocked_request( "POST", "/", headers={"CONTENT-TYPE": content_type}, payload=payload ) with pytest.raises(ValueError, match="Multipart field missing name"): await req.post() async def test_multipart_formdata_file(protocol: BaseProtocol) -> None: # Make sure file uploads work, even without a content type payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) payload.feed_data( b"-----------------------------326931944431359\r\n" b'Content-Disposition: form-data; name="a_file"; filename="binary"\r\n' b"\r\n" b"\ff\r\n" b"-----------------------------326931944431359--\r\n" ) content_type = ( "multipart/form-data; boundary=---------------------------326931944431359" ) payload.feed_eof() req = make_mocked_request( "POST", "/", headers={"CONTENT-TYPE": content_type}, payload=payload ) result = await req.post() assert hasattr(result["a_file"], "file") content = result["a_file"].file.read() assert content == b"\ff" req._finish() async def test_multipart_formdata_headers_too_many(protocol: BaseProtocol) -> None: many = b"".join(f"X-{i}: a\r\n".encode() for i in range(130)) body = ( b"--b\r\n" b'Content-Disposition: form-data; name="a"\r\n' + many + b"\r\n1\r\n" b"--b--\r\n" ) content_type = "multipart/form-data; boundary=b" payload = StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) payload.feed_data(body) payload.feed_eof() req = make_mocked_request( "POST", "/", headers={"CONTENT-TYPE": content_type}, payload=payload, ) with pytest.raises(BadHttpMessage, match="Too many headers received"): await req.post() async def test_multipart_formdata_header_too_long(protocol: BaseProtocol) -> None: k = b"t" * 4100 body = ( b"--b\r\n" b'Content-Disposition: form-data; name="a"\r\n' + k + b":" + k + b"\r\n" + b"\r\n1\r\n" b"--b--\r\n" ) content_type = 
"multipart/form-data; boundary=b" payload = StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) payload.feed_data(body) payload.feed_eof() req = make_mocked_request( "POST", "/", headers={"CONTENT-TYPE": content_type}, payload=payload, ) match = "400, message:\n Got more than 8190 bytes when reading" with pytest.raises(LineTooLong, match=match): await req.post() async def test_make_too_big_request_limit_None(protocol: BaseProtocol) -> None: payload = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) large_file = 1024**2 * b"x" too_large_file = large_file + b"x" payload.feed_data(too_large_file) payload.feed_eof() req = make_mocked_request("POST", "/", payload=payload, client_max_size=0) txt = await req.read() assert len(txt) == 1024**2 + 1 def test_remote_peername_tcp() -> None: transp = mock.Mock() transp.get_extra_info.return_value = ("10.10.10.10", 1234) req = make_mocked_request("GET", "/", transport=transp) assert req.remote == "10.10.10.10" def test_remote_peername_unix() -> None: transp = mock.Mock() transp.get_extra_info.return_value = "/path/to/sock" req = make_mocked_request("GET", "/", transport=transp) assert req.remote == "/path/to/sock" def test_save_state_on_clone() -> None: req = make_mocked_request("GET", "/") req["key"] = "val" req2 = req.clone() req2["key"] = "val2" assert req["key"] == "val" assert req2["key"] == "val2" def test_clone_scheme() -> None: req = make_mocked_request("GET", "/") assert req.scheme == "http" req2 = req.clone(scheme="https") assert req2.scheme == "https" assert req2.url.scheme == "https" def test_clone_host() -> None: req = make_mocked_request("GET", "/") assert req.host != "example.com" req2 = req.clone(host="example.com") assert req2.host == "example.com" assert req2.url.host == "example.com" def test_clone_remote() -> None: req = make_mocked_request("GET", "/") assert req.remote != "11.11.11.11" req2 = req.clone(remote="11.11.11.11") assert req2.remote == "11.11.11.11" def 
test_remote_with_closed_transport() -> None: transp = mock.Mock() transp.get_extra_info.return_value = ("10.10.10.10", 1234) req = make_mocked_request("GET", "/", transport=transp) req._protocol = None # type: ignore[assignment] assert req.remote == "10.10.10.10" def test_url_http_with_closed_transport() -> None: req = make_mocked_request("GET", "/") req._protocol = None # type: ignore[assignment] assert str(req.url).startswith("http://") def test_url_https_with_closed_transport() -> None: c = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) req = make_mocked_request("GET", "/", sslcontext=c) req._protocol = None # type: ignore[assignment] assert str(req.url).startswith("https://") async def test_get_extra_info() -> None: valid_key = "test" valid_value = "existent" default_value = "default" def get_extra_info(name: str, default: object = None) -> object: return {valid_key: valid_value}.get(name, default) transp = mock.Mock() transp.get_extra_info.side_effect = get_extra_info req = make_mocked_request("GET", "/", transport=transp) assert req is not None req_extra_info = req.get_extra_info(valid_key, default_value) assert req._protocol.transport is not None transp_extra_info = req._protocol.transport.get_extra_info(valid_key, default_value) assert req_extra_info == transp_extra_info req._protocol.transport = None extra_info = req.get_extra_info(valid_key, default_value) assert extra_info == default_value def test_eq() -> None: req1 = make_mocked_request("GET", "/path/to?a=1&b=2") req2 = make_mocked_request("GET", "/path/to?a=1&b=2") assert req1 != req2 assert req1 == req1 async def test_json(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: body_text = await request.text() assert body_text == '{"some": "data"}' assert request.headers["Content-Type"] == "application/json" body_json = await request.json() assert body_json == {"some": "data"} return web.Response() app = web.Application() app.router.add_post("/", handler) client = 
await aiohttp_client(app) json_data = {"some": "data"} async with client.post("/", json=json_data) as resp: assert 200 == resp.status async def test_json_invalid_content_type(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> NoReturn: body_text = await request.text() assert body_text == '{"some": "data"}' assert request.headers["Content-Type"] == "text/plain" await request.json() # raises HTTP 400 assert False app = web.Application() app.router.add_post("/", handler) client = await aiohttp_client(app) json_data = {"some": "data"} headers = {"Content-Type": "text/plain"} async with client.post("/", json=json_data, headers=headers) as resp: assert 400 == resp.status resp_text = await resp.text() assert resp_text == ( "Attempt to decode JSON with unexpected mimetype: text/plain" ) def test_weakref_creation() -> None: req = make_mocked_request("GET", "/") weakref.ref(req) @pytest.mark.parametrize( ("header", "header_attr"), ( pytest.param("If-Match", "if_match"), pytest.param("If-None-Match", "if_none_match"), ), ) @pytest.mark.parametrize( ("header_val", "expected"), ( pytest.param( '"67ab43", W/"54ed21", "7892,dd"', ( ETag(is_weak=False, value="67ab43"), ETag(is_weak=True, value="54ed21"), ETag(is_weak=False, value="7892,dd"), ), ), pytest.param( '"bfc1ef-5b2c2730249c88ca92d82d"', (ETag(is_weak=False, value="bfc1ef-5b2c2730249c88ca92d82d"),), ), pytest.param( '"valid-tag", "also-valid-tag",somegarbage"last-tag"', ( ETag(is_weak=False, value="valid-tag"), ETag(is_weak=False, value="also-valid-tag"), ), ), pytest.param( '"ascii", "это точно не ascii", "ascii again"', (ETag(is_weak=False, value="ascii"),), ), pytest.param( "*", (ETag(is_weak=False, value="*"),), ), ), ) def test_etag_headers( header: str, header_attr: str, header_val: str, expected: tuple[ETag, ...] 
) -> None: req = make_mocked_request("GET", "/", headers={header: header_val}) assert getattr(req, header_attr) == expected @pytest.mark.parametrize( ("header", "header_attr"), ( pytest.param("If-Modified-Since", "if_modified_since"), pytest.param("If-Unmodified-Since", "if_unmodified_since"), pytest.param("If-Range", "if_range"), ), ) @pytest.mark.parametrize( ("header_val", "expected"), ( pytest.param("xxyyzz", None), pytest.param("Tue, 08 Oct 4446413 00:56:40 GMT", None), pytest.param("Tue, 08 Oct 2000 00:56:80 GMT", None), pytest.param( "Tue, 08 Oct 2000 00:56:40 GMT", datetime.datetime(2000, 10, 8, 0, 56, 40, tzinfo=datetime.timezone.utc), ), ), ) def test_datetime_headers( header: str, header_attr: str, header_val: str, expected: datetime.datetime | None, ) -> None: req = make_mocked_request("GET", "/", headers={header: header_val}) assert getattr(req, header_attr) == expected ================================================ FILE: tests/test_web_request_handler.py ================================================ from unittest import mock from aiohttp import web async def serve(request: web.BaseRequest) -> web.Response: assert False async def test_repr() -> None: manager = web.Server(serve) handler = manager() assert "" == repr(handler) with mock.patch.object(handler, "transport", autospec=True): assert "" == repr(handler) async def test_connections() -> None: manager = web.Server(serve) assert manager.connections == [] handler = mock.Mock(spec_set=web.RequestHandler) handler._task_handler = None transport = object() manager.connection_made(handler, transport) # type: ignore[arg-type] assert manager.connections == [handler] manager.connection_lost(handler, None) assert manager.connections == [] async def test_shutdown_no_timeout() -> None: manager = web.Server(serve) handler = mock.Mock(spec_set=web.RequestHandler) handler._task_handler = None handler.shutdown = mock.AsyncMock(return_value=mock.Mock()) transport = mock.Mock() manager.connection_made(handler, 
transport) await manager.shutdown() manager.connection_lost(handler, None) assert manager.connections == [] handler.shutdown.assert_called_with(None) async def test_shutdown_timeout() -> None: manager = web.Server(serve) handler = mock.Mock() handler.shutdown = mock.AsyncMock(return_value=mock.Mock()) transport = mock.Mock() manager.connection_made(handler, transport) await manager.shutdown(timeout=0.1) manager.connection_lost(handler, None) assert manager.connections == [] handler.shutdown.assert_called_with(0.1) ================================================ FILE: tests/test_web_response.py ================================================ import collections.abc import datetime import gzip import io import json import re import sys import weakref from collections.abc import AsyncIterator, Iterator from concurrent.futures import ThreadPoolExecutor from unittest import mock import aiosignal import pytest from multidict import CIMultiDict, CIMultiDictProxy, MultiDict from aiohttp import HttpVersion, HttpVersion10, HttpVersion11, hdrs, web from aiohttp.abc import AbstractStreamWriter from aiohttp.helpers import ETag from aiohttp.http_writer import StreamWriter, _serialize_headers from aiohttp.multipart import BodyPartReader, MultipartWriter from aiohttp.payload import BytesPayload, StringPayload from aiohttp.test_utils import make_mocked_request from aiohttp.typedefs import LooseHeaders def make_request( method: str, path: str, headers: LooseHeaders = CIMultiDict(), version: HttpVersion = HttpVersion11, *, app: web.Application | None = None, writer: AbstractStreamWriter | None = None, ) -> web.Request: if app is None: app = mock.create_autospec( web.Application, spec_set=True, on_response_prepare=aiosignal.Signal(app) ) app.on_response_prepare.freeze() return make_mocked_request( method, path, headers, version=version, app=app, writer=writer ) @pytest.fixture def buf() -> bytearray: return bytearray() @pytest.fixture def writer(buf: bytearray) -> 
AbstractStreamWriter: writer = mock.create_autospec(AbstractStreamWriter, spec_set=True) async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: b_headers = _serialize_headers(status_line, headers) buf.extend(b_headers) async def write_eof(chunk: bytes = b"") -> None: buf.extend(chunk) writer.write_eof.side_effect = write_eof writer.write_headers.side_effect = write_headers return writer # type: ignore[no-any-return] def test_stream_response_ctor() -> None: resp = web.StreamResponse() assert 200 == resp.status assert resp.keep_alive is None assert resp.task is None req = mock.Mock() resp._req = req assert resp.task is req.task def test_stream_response_hashable() -> None: # should not raise exception hash(web.StreamResponse()) def test_stream_response_eq() -> None: resp1 = web.StreamResponse() resp2 = web.StreamResponse() assert resp1 == resp1 assert not resp1 == resp2 def test_stream_response_is_mutable_mapping() -> None: resp = web.StreamResponse() assert isinstance(resp, collections.abc.MutableMapping) assert resp # even when the MutableMapping is empty, response should always be True resp["key"] = "value" assert "value" == resp["key"] def test_stream_response_delitem() -> None: resp = web.StreamResponse() resp["key"] = "value" del resp["key"] assert "key" not in resp def test_stream_response_len() -> None: resp = web.StreamResponse() assert len(resp) == 0 resp["key"] = "value" assert len(resp) == 1 def test_response_iter() -> None: resp = web.StreamResponse() resp["key"] = "value" resp["key2"] = "value2" key3 = web.ResponseKey("key3", str) resp[key3] = "value3" assert set(resp) == {"key", "key2", key3} def test_responsekey() -> None: resp = web.StreamResponse() key = web.ResponseKey("key", str) resp[key] = "value" assert resp[key] == "value" assert len(resp) == 1 del resp[key] assert len(resp) == 0 def test_response_get_responsekey() -> None: resp = web.StreamResponse() key = web.ResponseKey("key", int) assert resp.get(key, "foo") == "foo" 
resp[key] = 5 assert resp.get(key, "foo") == 5 def test_responsekey_repr_concrete() -> None: key = web.ResponseKey("key", int) assert repr(key) in ( "", # pytest-xdist "", ) key2 = web.ResponseKey("key", web.Request) assert repr(key2) in ( # pytest-xdist: "", "", ) def test_responsekey_repr_nonconcrete() -> None: key = web.ResponseKey("key", Iterator[int]) if sys.version_info < (3, 11): assert repr(key) in ( # pytest-xdist: "", "", ) else: assert repr(key) in ( # pytest-xdist: "", "", ) def test_responsekey_repr_annotated() -> None: key = web.ResponseKey[Iterator[int]]("key") if sys.version_info < (3, 11): assert repr(key) in ( # pytest-xdist: "", "", ) else: assert repr(key) in ( # pytest-xdist: "", "", ) def test_content_length() -> None: resp = web.StreamResponse() assert resp.content_length is None def test_content_length_setter() -> None: resp = web.StreamResponse() resp.content_length = 234 assert 234 == resp.content_length def test_content_length_setter_with_enable_chunked_encoding() -> None: resp = web.StreamResponse() resp.enable_chunked_encoding() with pytest.raises(RuntimeError): resp.content_length = 234 def test_drop_content_length_header_on_setting_len_to_None() -> None: resp = web.StreamResponse() resp.content_length = 1 assert "1" == resp.headers["Content-Length"] resp.content_length = None assert "Content-Length" not in resp.headers def test_set_content_length_to_None_on_non_set() -> None: resp = web.StreamResponse() resp.content_length = None assert "Content-Length" not in resp.headers resp.content_length = None assert "Content-Length" not in resp.headers def test_setting_content_type() -> None: resp = web.StreamResponse() resp.content_type = "text/html" assert "text/html" == resp.headers["content-type"] def test_setting_charset() -> None: resp = web.StreamResponse() resp.content_type = "text/html" resp.charset = "koi8-r" assert "text/html; charset=koi8-r" == resp.headers["content-type"] def test_default_charset() -> None: resp = 
web.StreamResponse() assert resp.charset is None def test_reset_charset() -> None: resp = web.StreamResponse() resp.content_type = "text/html" resp.charset = None assert resp.charset is None def test_reset_charset_after_setting() -> None: resp = web.StreamResponse() resp.content_type = "text/html" resp.charset = "koi8-r" resp.charset = None assert resp.charset is None def test_charset_without_content_type() -> None: resp = web.StreamResponse() with pytest.raises(RuntimeError): resp.charset = "koi8-r" def test_last_modified_initial() -> None: resp = web.StreamResponse() assert resp.last_modified is None def test_last_modified_string() -> None: resp = web.StreamResponse() dt = datetime.datetime(1990, 1, 2, 3, 4, 5, 0, datetime.timezone.utc) resp.last_modified = "Mon, 2 Jan 1990 03:04:05 GMT" assert resp.last_modified == dt def test_last_modified_timestamp() -> None: resp = web.StreamResponse() dt = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, datetime.timezone.utc) resp.last_modified = 0 assert resp.last_modified == dt resp.last_modified = 0.0 assert resp.last_modified == dt def test_last_modified_datetime() -> None: resp = web.StreamResponse() dt = datetime.datetime(2001, 2, 3, 4, 5, 6, 0, datetime.timezone.utc) resp.last_modified = dt assert resp.last_modified == dt def test_last_modified_reset() -> None: resp = web.StreamResponse() resp.last_modified = 0 resp.last_modified = None assert resp.last_modified is None def test_last_modified_invalid_type() -> None: resp = web.StreamResponse() with pytest.raises(TypeError, match="Unsupported type for last_modified: object"): resp.last_modified = object() # type: ignore[assignment] @pytest.mark.parametrize( "header_val", ( "xxyyzz", "Tue, 08 Oct 4446413 00:56:40 GMT", "Tue, 08 Oct 2000 00:56:80 GMT", ), ) def test_last_modified_string_invalid(header_val: str) -> None: resp = web.StreamResponse(headers={"Last-Modified": header_val}) assert resp.last_modified is None def test_etag_initial() -> None: resp = web.StreamResponse() 
assert resp.etag is None def test_etag_string() -> None: resp = web.StreamResponse() value = "0123-kotik" resp.etag = value assert resp.etag == ETag(value=value) assert resp.headers[hdrs.ETAG] == f'"{value}"' @pytest.mark.parametrize( ("etag", "expected_header"), ( (ETag(value="0123-weak-kotik", is_weak=True), 'W/"0123-weak-kotik"'), (ETag(value="0123-strong-kotik", is_weak=False), '"0123-strong-kotik"'), ), ) def test_etag_class(etag: ETag, expected_header: str) -> None: resp = web.StreamResponse() resp.etag = etag assert resp.etag == etag assert resp.headers[hdrs.ETAG] == expected_header def test_etag_any() -> None: resp = web.StreamResponse() resp.etag = "*" assert resp.etag == ETag(value="*") assert resp.headers[hdrs.ETAG] == "*" @pytest.mark.parametrize( "invalid_value", ( '"invalid"', "повинен бути ascii", ETag(value='"invalid"', is_weak=True), ETag(value="bad ©®"), ), ) def test_etag_invalid_value_set(invalid_value: str | ETag) -> None: resp = web.StreamResponse() with pytest.raises(ValueError, match="is not a valid etag"): resp.etag = invalid_value @pytest.mark.parametrize( "header", ( "forgotten quotes", '"∀ x ∉ ascii"', ), ) def test_etag_invalid_value_get(header: str) -> None: resp = web.StreamResponse() resp.headers["ETag"] = header assert resp.etag is None @pytest.mark.parametrize("invalid", (123, ETag(value=123, is_weak=True))) # type: ignore[arg-type] def test_etag_invalid_value_class(invalid: int | ETag) -> None: resp = web.StreamResponse() with pytest.raises(ValueError, match="Unsupported etag type"): resp.etag = invalid # type: ignore[assignment] def test_etag_reset() -> None: resp = web.StreamResponse() resp.etag = "*" resp.etag = None assert resp.etag is None async def test_start() -> None: req = make_request("GET", "/") resp = web.StreamResponse() assert resp.keep_alive is None msg = await resp.prepare(req) assert msg is not None assert msg.write_headers.called # type: ignore[attr-defined] msg2 = await resp.prepare(req) assert msg is msg2 
assert resp.keep_alive req2 = make_request("GET", "/") # type: ignore[unreachable] # with pytest.raises(RuntimeError): msg3 = await resp.prepare(req2) assert msg is msg3 async def test_chunked_encoding() -> None: req = make_request("GET", "/") resp = web.StreamResponse() assert not resp.chunked resp.enable_chunked_encoding() assert resp.chunked msg = await resp.prepare(req) # type: ignore[unreachable] assert msg.chunked def test_enable_chunked_encoding_with_content_length() -> None: resp = web.StreamResponse() resp.content_length = 234 with pytest.raises(RuntimeError): resp.enable_chunked_encoding() async def test_chunked_encoding_forbidden_for_http_10() -> None: req = make_request("GET", "/", version=HttpVersion10) resp = web.StreamResponse() resp.enable_chunked_encoding() with pytest.raises(RuntimeError) as ctx: await resp.prepare(req) assert str(ctx.value) == "Using chunked encoding is forbidden for HTTP/1.0" @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_compression_no_accept() -> None: req = make_request("GET", "/") resp = web.StreamResponse() assert not resp.chunked assert not resp.compression resp.enable_compression() assert resp.compression msg = await resp.prepare(req) # type: ignore[unreachable] assert not msg.enable_compression.called @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_compression_default_coding() -> None: req = make_request( "GET", "/", headers=CIMultiDict({hdrs.ACCEPT_ENCODING: "gzip, deflate"}) ) resp = web.StreamResponse() assert not resp.chunked assert not resp.compression resp.enable_compression() assert resp.compression msg = await resp.prepare(req) # type: ignore[unreachable] msg.enable_compression.assert_called_with("deflate", None) assert "deflate" == resp.headers.get(hdrs.CONTENT_ENCODING) assert msg.filter is not None @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_deflate() -> None: req = make_request( "GET", "/", 
headers=CIMultiDict({hdrs.ACCEPT_ENCODING: "gzip, deflate"}) ) resp = web.StreamResponse() resp.enable_compression(web.ContentCoding.deflate) assert resp.compression msg = await resp.prepare(req) assert msg is not None msg.enable_compression.assert_called_with("deflate", None) # type: ignore[attr-defined] assert "deflate" == resp.headers.get(hdrs.CONTENT_ENCODING) @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_deflate_large_payload() -> None: """Make sure a warning is thrown for large payloads compressed in the event loop.""" req = make_request( "GET", "/", headers=CIMultiDict({hdrs.ACCEPT_ENCODING: "gzip, deflate"}) ) resp = web.Response(body=b"large") resp.enable_compression(web.ContentCoding.deflate) assert resp.compression with ( pytest.warns(Warning, match="Synchronous compression of large response bodies"), mock.patch("aiohttp.web_response.LARGE_BODY_SIZE", 2), ): msg = await resp.prepare(req) assert msg is not None assert "deflate" == resp.headers.get(hdrs.CONTENT_ENCODING) @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_no_accept_deflate() -> None: req = make_request("GET", "/") resp = web.StreamResponse() resp.enable_compression(web.ContentCoding.deflate) assert resp.compression msg = await resp.prepare(req) assert msg is not None msg.enable_compression.assert_called_with("deflate", None) # type: ignore[attr-defined] assert "deflate" == resp.headers.get(hdrs.CONTENT_ENCODING) @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_gzip() -> None: req = make_request( "GET", "/", headers=CIMultiDict({hdrs.ACCEPT_ENCODING: "gzip, deflate"}) ) resp = web.StreamResponse() resp.enable_compression(web.ContentCoding.gzip) assert resp.compression msg = await resp.prepare(req) assert msg is not None msg.enable_compression.assert_called_with("gzip", None) # type: ignore[attr-defined] assert "gzip" == resp.headers.get(hdrs.CONTENT_ENCODING) 
@pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_no_accept_gzip() -> None: req = make_request("GET", "/") resp = web.StreamResponse() resp.enable_compression(web.ContentCoding.gzip) assert resp.compression msg = await resp.prepare(req) assert msg is not None msg.enable_compression.assert_called_with("gzip", None) # type: ignore[attr-defined] assert "gzip" == resp.headers.get(hdrs.CONTENT_ENCODING) @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_change_content_threaded_compression_enabled() -> None: req = make_request("GET", "/") body_thread_size = 1024 body = b"answer" * body_thread_size resp = web.Response(body=body, zlib_executor_size=body_thread_size) resp.enable_compression(web.ContentCoding.gzip) await resp.prepare(req) assert resp._compressed_body is not None assert gzip.decompress(resp._compressed_body) == body @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_change_content_threaded_compression_enabled_explicit() -> None: req = make_request("GET", "/") body_thread_size = 1024 body = b"answer" * body_thread_size with ThreadPoolExecutor(1) as executor: resp = web.Response( body=body, zlib_executor_size=body_thread_size, zlib_executor=executor ) resp.enable_compression(web.ContentCoding.gzip) await resp.prepare(req) assert resp._compressed_body is not None assert gzip.decompress(resp._compressed_body) == body @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_change_content_length_if_compression_enabled() -> None: req = make_request("GET", "/") resp = web.Response(body=b"answer") resp.enable_compression(web.ContentCoding.gzip) await resp.prepare(req) assert resp.content_length is not None and resp.content_length != len(b"answer") @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_set_content_length_if_compression_enabled() -> None: writer = mock.Mock() async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert 
hdrs.CONTENT_LENGTH in headers assert headers[hdrs.CONTENT_LENGTH] == "26" assert hdrs.TRANSFER_ENCODING not in headers writer.write_headers.side_effect = write_headers req = make_request("GET", "/", writer=writer) resp = web.Response(body=b"answer") resp.enable_compression(web.ContentCoding.gzip) await resp.prepare(req) assert resp.content_length == 26 del resp.headers[hdrs.CONTENT_LENGTH] assert resp.content_length == 26 @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_remove_content_length_if_compression_enabled_http11() -> None: writer = mock.Mock() async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert hdrs.CONTENT_LENGTH not in headers assert headers.get(hdrs.TRANSFER_ENCODING, "") == "chunked" writer.write_headers.side_effect = write_headers req = make_request("GET", "/", writer=writer) resp = web.StreamResponse() resp.content_length = 123 resp.enable_compression(web.ContentCoding.gzip) await resp.prepare(req) assert resp.content_length is None @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_remove_content_length_if_compression_enabled_http10() -> None: writer = mock.Mock() async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert hdrs.CONTENT_LENGTH not in headers assert hdrs.TRANSFER_ENCODING not in headers writer.write_headers.side_effect = write_headers req = make_request("GET", "/", version=HttpVersion10, writer=writer) resp = web.StreamResponse() resp.content_length = 123 resp.enable_compression(web.ContentCoding.gzip) await resp.prepare(req) assert resp.content_length is None @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_identity() -> None: writer = mock.Mock() async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert hdrs.CONTENT_LENGTH in headers assert hdrs.TRANSFER_ENCODING not in headers writer.write_headers.side_effect = write_headers req = make_request("GET", "/", writer=writer) resp 
= web.StreamResponse() resp.content_length = 123 resp.enable_compression(web.ContentCoding.identity) await resp.prepare(req) assert resp.content_length == 123 @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_force_compression_identity_response() -> None: writer = mock.Mock() async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert headers[hdrs.CONTENT_LENGTH] == "6" assert hdrs.TRANSFER_ENCODING not in headers writer.write_headers.side_effect = write_headers req = make_request("GET", "/", writer=writer) resp = web.Response(body=b"answer") resp.enable_compression(web.ContentCoding.identity) await resp.prepare(req) assert resp.content_length == 6 async def test_enable_compression_with_existing_encoding() -> None: """Test that enable_compression does not override existing content encoding.""" writer = mock.Mock() async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: # Should preserve the existing content encoding assert headers[hdrs.CONTENT_ENCODING] == "gzip" # Should not have double encoding assert headers.get(hdrs.CONTENT_ENCODING) != "gzip, deflate" writer.write_headers.side_effect = write_headers req = make_request("GET", "/", writer=writer) resp = web.Response(body=b"answer") # Manually set content encoding (simulating FileResponse with pre-compressed file) resp.headers[hdrs.CONTENT_ENCODING] = "gzip" # Try to enable compression - should be ignored resp.enable_compression(web.ContentCoding.deflate) await resp.prepare(req) # Verify compression was not enabled due to existing encoding assert not resp.compression @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_rm_content_length_if_compression_http11() -> None: writer = mock.Mock() async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert hdrs.CONTENT_LENGTH not in headers assert headers.get(hdrs.TRANSFER_ENCODING, "") == "chunked" writer.write_headers.side_effect = write_headers req = 
make_request("GET", "/", writer=writer) payload = BytesPayload(b"answer", headers={"X-Test-Header": "test"}) resp = web.Response(body=payload) resp.body = payload resp.enable_compression(web.ContentCoding.gzip) await resp.prepare(req) assert resp.content_length is None @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_rm_content_length_if_compression_http10() -> None: writer = mock.Mock() async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert hdrs.CONTENT_LENGTH not in headers assert hdrs.TRANSFER_ENCODING not in headers writer.write_headers.side_effect = write_headers req = make_request("GET", "/", version=HttpVersion10, writer=writer) resp = web.Response(body=BytesPayload(b"answer")) resp.enable_compression(web.ContentCoding.gzip) await resp.prepare(req) assert resp.content_length is None async def test_rm_content_length_if_204() -> None: """Ensure content-length is removed for 204 responses.""" writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) async def write_headers(status_line: str, headers: CIMultiDict[str]) -> None: assert hdrs.CONTENT_LENGTH not in headers writer.write_headers.side_effect = write_headers req = make_request("GET", "/", writer=writer) payload = BytesPayload(b"answer", headers={"Content-Length": "6"}) resp = web.Response(body=payload, status=204) resp.body = payload await resp.prepare(req) assert resp.content_length is None @pytest.mark.parametrize("status", (100, 101, 204, 304)) async def test_rm_transfer_encoding_rfc_9112_6_3_http_11(status: int) -> None: """Remove transfer encoding for RFC 9112 sec 6.3 with HTTP/1.1.""" writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) req = make_request("GET", "/", version=HttpVersion11, writer=writer) resp = web.Response(status=status, headers={hdrs.TRANSFER_ENCODING: "chunked"}) await resp.prepare(req) assert resp.content_length == 0 assert not resp.chunked assert hdrs.CONTENT_LENGTH not in resp.headers assert 
hdrs.TRANSFER_ENCODING not in resp.headers @pytest.mark.parametrize("status", (100, 101, 102, 204, 304)) async def test_rm_content_length_1xx_204_304_responses(status: int) -> None: """Remove content length for 1xx, 204, and 304 responses. Content-Length is forbidden for 1xx and 204 https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.2 Content-Length is discouraged for 304. https://datatracker.ietf.org/doc/html/rfc7232#section-4.1 """ writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) req = make_request("GET", "/", version=HttpVersion11, writer=writer) resp = web.Response(status=status, body="answer") await resp.prepare(req) assert not resp.chunked assert hdrs.CONTENT_LENGTH not in resp.headers assert hdrs.TRANSFER_ENCODING not in resp.headers async def test_head_response_keeps_content_length_of_original_body() -> None: """Verify HEAD response keeps the content length of the original body HTTP/1.1.""" writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) req = make_request("HEAD", "/", version=HttpVersion11, writer=writer) resp = web.Response(status=200, body=b"answer") await resp.prepare(req) assert resp.content_length == 6 assert not resp.chunked assert resp.headers[hdrs.CONTENT_LENGTH] == "6" assert hdrs.TRANSFER_ENCODING not in resp.headers async def test_head_response_omits_content_length_when_body_unset() -> None: """Verify HEAD response omits content-length body when its unset.""" writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) req = make_request("HEAD", "/", version=HttpVersion11, writer=writer) resp = web.Response(status=200) await resp.prepare(req) assert resp.content_length == 0 assert not resp.chunked assert hdrs.CONTENT_LENGTH not in resp.headers assert hdrs.TRANSFER_ENCODING not in resp.headers async def test_304_response_omits_content_length_when_body_unset() -> None: """Verify 304 response omits content-length body when its unset.""" writer = 
mock.create_autospec(StreamWriter, spec_set=True, instance=True) req = make_request("GET", "/", version=HttpVersion11, writer=writer) resp = web.Response(status=304) await resp.prepare(req) assert resp.content_length == 0 assert not resp.chunked assert hdrs.CONTENT_LENGTH not in resp.headers assert hdrs.TRANSFER_ENCODING not in resp.headers async def test_content_length_on_chunked() -> None: req = make_request("GET", "/") resp = web.Response(body=b"answer") assert resp.content_length == 6 resp.enable_chunked_encoding() assert resp.content_length is None await resp.prepare(req) # type: ignore[unreachable] async def test_write_non_byteish() -> None: resp = web.StreamResponse() await resp.prepare(make_request("GET", "/")) with pytest.raises(AssertionError): await resp.write(123) # type: ignore[arg-type] async def test_write_before_start() -> None: resp = web.StreamResponse() with pytest.raises(RuntimeError): await resp.write(b"data") async def test_cannot_write_after_eof() -> None: resp = web.StreamResponse() req = make_request("GET", "/") await resp.prepare(req) await resp.write(b"data") await resp.write_eof() req.writer.write.reset_mock() # type: ignore[attr-defined] with pytest.raises(RuntimeError): await resp.write(b"next data") assert not req.writer.write.called # type: ignore[attr-defined] async def test___repr___after_eof() -> None: resp = web.StreamResponse() await resp.prepare(make_request("GET", "/")) await resp.write(b"data") await resp.write_eof() resp_repr = repr(resp) assert resp_repr == "" async def test_cannot_write_eof_before_headers() -> None: resp = web.StreamResponse() with pytest.raises(AssertionError): await resp.write_eof() async def test_cannot_write_eof_twice() -> None: resp = web.StreamResponse() writer = mock.create_autospec(AbstractStreamWriter, spec_set=True) writer.write.return_value = None writer.write_eof.return_value = None resp_impl = await resp.prepare(make_request("GET", "/", writer=writer)) await resp.write(b"data") assert 
resp_impl is not None assert resp_impl.write.called # type: ignore[attr-defined] await resp.write_eof() resp_impl.write.reset_mock() # type: ignore[attr-defined] await resp.write_eof() assert not writer.write.called def test_force_close() -> None: resp = web.StreamResponse() assert resp.keep_alive is None resp.force_close() assert resp.keep_alive is False def test_set_status_with_reason() -> None: resp = web.StreamResponse() resp.set_status(200, "Everything is fine!") assert 200 == resp.status assert "Everything is fine!" == resp.reason def test_set_status_with_empty_reason() -> None: resp = web.StreamResponse() resp.set_status(200, "") assert resp.status == 200 assert resp.reason == "" def test_set_status_reason_with_cr() -> None: resp = web.StreamResponse() with pytest.raises(ValueError, match="Reason cannot contain"): resp.set_status(200, "OK\rSet-Cookie: evil=1") def test_set_status_reason_with_lf() -> None: resp = web.StreamResponse() with pytest.raises(ValueError, match="Reason cannot contain"): resp.set_status(200, "OK\nSet-Cookie: evil=1") def test_set_status_reason_with_crlf() -> None: resp = web.StreamResponse() with pytest.raises(ValueError, match="Reason cannot contain"): resp.set_status(200, "OK\r\nSet-Cookie: evil=1") async def test_start_force_close() -> None: req = make_request("GET", "/") resp = web.StreamResponse() resp.force_close() assert not resp.keep_alive await resp.prepare(req) assert not resp.keep_alive async def test___repr__() -> None: req = make_request("GET", "/path/to") resp = web.StreamResponse(reason="foo") await resp.prepare(req) assert "" == repr(resp) def test___repr___not_prepared() -> None: resp = web.StreamResponse(reason="foo") assert "" == repr(resp) async def test_keep_alive_http10_default() -> None: req = make_request("GET", "/", version=HttpVersion10) resp = web.StreamResponse() await resp.prepare(req) assert not resp.keep_alive async def test_keep_alive_http10_switched_on() -> None: headers = 
CIMultiDict(Connection="keep-alive") req = make_request("GET", "/", version=HttpVersion10, headers=headers) req._message = req._message._replace(should_close=False) resp = web.StreamResponse() await resp.prepare(req) assert resp.keep_alive async def test_keep_alive_http09() -> None: headers = CIMultiDict(Connection="keep-alive") req = make_request("GET", "/", version=HttpVersion(0, 9), headers=headers) resp = web.StreamResponse() await resp.prepare(req) assert not resp.keep_alive async def test_prepare_twice() -> None: req = make_request("GET", "/") resp = web.StreamResponse() impl1 = await resp.prepare(req) impl2 = await resp.prepare(req) assert impl1 is impl2 async def test_prepare_calls_signal() -> None: app = mock.create_autospec(web.Application, spec_set=True) sig = mock.AsyncMock() app.on_response_prepare = aiosignal.Signal(app) app.on_response_prepare.append(sig) req = make_request("GET", "/", app=app) resp = web.StreamResponse() await resp.prepare(req) sig.assert_called_with(req, resp) # Response class def test_response_ctor() -> None: resp = web.Response() assert 200 == resp.status assert "OK" == resp.reason assert resp.body is None assert resp.content_length == 0 assert "CONTENT-LENGTH" not in resp.headers async def test_ctor_with_headers_and_status() -> None: resp = web.Response(body=b"body", status=201, headers={"Age": "12", "DATE": "date"}) assert 201 == resp.status assert b"body" == resp.body assert resp.headers["AGE"] == "12" req = make_mocked_request("GET", "/") await resp._start(req) assert 4 == resp.content_length assert resp.headers["CONTENT-LENGTH"] == "4" def test_ctor_content_type() -> None: resp = web.Response(content_type="application/json") assert 200 == resp.status assert "OK" == resp.reason assert 0 == resp.content_length assert CIMultiDict([("CONTENT-TYPE", "application/json")]) == resp.headers def test_ctor_text_body_combined() -> None: with pytest.raises(ValueError): web.Response(body=b"123", text="test text") async def 
test_ctor_text() -> None: resp = web.Response(text="test text") assert 200 == resp.status assert "OK" == resp.reason assert 9 == resp.content_length assert CIMultiDict([("CONTENT-TYPE", "text/plain; charset=utf-8")]) == resp.headers assert resp.body == b"test text" assert resp.text == "test text" resp.headers["DATE"] = "date" req = make_mocked_request("GET", "/", version=HttpVersion11) await resp._start(req) assert resp.headers["CONTENT-LENGTH"] == "9" def test_ctor_charset() -> None: resp = web.Response(text="текст", charset="koi8-r") assert "текст".encode("koi8-r") == resp.body assert "koi8-r" == resp.charset def test_ctor_charset_default_utf8() -> None: resp = web.Response(text="test test", charset=None) assert "utf-8" == resp.charset def test_ctor_charset_in_content_type() -> None: with pytest.raises(ValueError): web.Response(text="test test", content_type="text/plain; charset=utf-8") def test_ctor_charset_without_text() -> None: resp = web.Response(content_type="text/plain", charset="koi8-r") assert "koi8-r" == resp.charset def test_ctor_content_type_with_extra() -> None: resp = web.Response(text="test test", content_type="text/plain; version=0.0.4") assert resp.content_type == "text/plain" assert resp.headers["content-type"] == "text/plain; version=0.0.4; charset=utf-8" def test_invalid_content_type_parses_to_application_octect_stream() -> None: resp = web.Response(text="test test", content_type="jpeg") assert resp.content_type == "application/octet-stream" assert resp.headers["content-type"] == "jpeg; charset=utf-8" def test_ctor_both_content_type_param_and_header_with_text() -> None: with pytest.raises(ValueError): web.Response( headers={"Content-Type": "application/json"}, content_type="text/html", text="text", ) def test_ctor_both_charset_param_and_header_with_text() -> None: with pytest.raises(ValueError): web.Response( headers={"Content-Type": "application/json"}, charset="koi8-r", text="text" ) def test_ctor_both_content_type_param_and_header() -> 
None: with pytest.raises(ValueError): web.Response( headers={"Content-Type": "application/json"}, content_type="text/html" ) def test_ctor_both_charset_param_and_header() -> None: with pytest.raises(ValueError): web.Response(headers={"Content-Type": "application/json"}, charset="koi8-r") async def test_assign_nonbyteish_body() -> None: resp = web.Response(body=b"data") with pytest.raises(ValueError): resp.body = 123 assert b"data" == resp.body assert 4 == resp.content_length resp.headers["DATE"] = "date" req = make_mocked_request("GET", "/", version=HttpVersion11) await resp._start(req) assert resp.headers["CONTENT-LENGTH"] == "4" assert 4 == resp.content_length def test_assign_nonstr_text() -> None: resp = web.Response(text="test") with pytest.raises(AssertionError): resp.text = b"123" # type: ignore[assignment] assert b"test" == resp.body assert 4 == resp.content_length mpwriter = MultipartWriter(boundary="x") mpwriter.append_payload(StringPayload("test")) async def async_iter() -> AsyncIterator[str]: yield "foo" # pragma: no cover class CustomIO(io.IOBase): def __init__(self) -> None: self._lines = [b"", b"", b"test"] def read(self, size: int = -1) -> bytes: return self._lines.pop() @pytest.mark.parametrize( "payload,expected", ( ("test", "test"), (CustomIO(), "test"), (io.StringIO("test"), "test"), (io.TextIOWrapper(io.BytesIO(b"test")), "test"), (io.BytesIO(b"test"), "test"), (io.BufferedReader(io.BytesIO(b"test")), "test"), (async_iter(), None), (BodyPartReader(b"x", CIMultiDictProxy(CIMultiDict()), mock.Mock()), None), ( mpwriter, "--x\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 4\r\n\r\ntest", ), ), ) def test_payload_body_get_text(payload: object, expected: str | None) -> None: resp = web.Response(body=payload) if expected is None: with pytest.raises(TypeError): resp.text else: assert resp.text == expected def test_response_set_content_length() -> None: resp = web.Response() with pytest.raises(RuntimeError): resp.content_length = 1 async 
def test_send_headers_for_empty_body( buf: bytearray, writer: AbstractStreamWriter ) -> None: req = make_request("GET", "/", writer=writer) resp = web.Response() await resp.prepare(req) await resp.write_eof() txt = buf.decode("utf8") lines = txt.split("\r\n") assert len(lines) == 6 assert lines[0] == "HTTP/1.1 200 OK" assert lines[1] == "Content-Length: 0" assert lines[2].startswith("Date: ") assert lines[3].startswith("Server: ") assert lines[4] == lines[5] == "" async def test_render_with_body(buf: bytearray, writer: AbstractStreamWriter) -> None: req = make_request("GET", "/", writer=writer) resp = web.Response(body=b"data") await resp.prepare(req) await resp.write_eof() txt = buf.decode("utf8") lines = txt.split("\r\n") assert len(lines) == 7 assert lines[0] == "HTTP/1.1 200 OK" assert lines[1] == "Content-Length: 4" assert lines[2] == "Content-Type: application/octet-stream" assert lines[3].startswith("Date: ") assert lines[4].startswith("Server: ") assert lines[5] == "" assert lines[6] == "data" async def test_multiline_reason(buf: bytearray, writer: AbstractStreamWriter) -> None: with pytest.raises(ValueError, match=r"Reason cannot contain \\r or \\n"): web.Response(reason="Bad\r\nInjected-header: foo") async def test_send_set_cookie_header( buf: bytearray, writer: AbstractStreamWriter ) -> None: resp = web.Response() resp.cookies["name"] = "value" req = make_request("GET", "/", writer=writer) await resp.prepare(req) await resp.write_eof() txt = buf.decode("utf8") lines = txt.split("\r\n") assert len(lines) == 7 assert lines[0] == "HTTP/1.1 200 OK" assert lines[1] == "Content-Length: 0" assert lines[2] == "Set-Cookie: name=value" assert lines[3].startswith("Date: ") assert lines[4].startswith("Server: ") assert lines[5] == lines[6] == "" async def test_consecutive_write_eof() -> None: writer = mock.create_autospec(AbstractStreamWriter, spec_set=True, instance=True) req = make_request("GET", "/", writer=writer) data = b"data" resp = web.Response(body=data) 
await resp.prepare(req) await resp.write_eof() await resp.write_eof() writer.write_eof.assert_called_once_with(data) def test_set_text_with_content_type() -> None: resp = web.Response() resp.content_type = "text/html" resp.text = "text" assert "text" == resp.text assert b"text" == resp.body assert "text/html" == resp.content_type def test_set_text_with_charset() -> None: resp = web.Response() resp.content_type = "text/plain" resp.charset = "KOI8-R" resp.text = "текст" assert "текст" == resp.text assert "текст".encode("koi8-r") == resp.body assert "koi8-r" == resp.charset def test_default_content_type_in_stream_response() -> None: resp = web.StreamResponse() assert resp.content_type == "application/octet-stream" def test_default_content_type_in_response() -> None: resp = web.Response() assert resp.content_type == "application/octet-stream" def test_content_type_with_set_text() -> None: resp = web.Response(text="text") assert resp.content_type == "text/plain" def test_content_type_with_set_body() -> None: resp = web.Response(body=b"body") assert resp.content_type == "application/octet-stream" def test_prepared_when_not_started() -> None: resp = web.StreamResponse() assert not resp.prepared async def test_prepared_when_started() -> None: resp = web.StreamResponse() await resp.prepare(make_request("GET", "/")) assert resp.prepared async def test_prepared_after_eof() -> None: resp = web.StreamResponse() await resp.prepare(make_request("GET", "/")) await resp.write(b"data") await resp.write_eof() assert resp.prepared async def test_drain_before_start() -> None: resp = web.StreamResponse() with pytest.raises(AssertionError): await resp.drain() async def test_changing_status_after_prepare_raises() -> None: resp = web.StreamResponse() await resp.prepare(make_request("GET", "/")) with pytest.raises(AssertionError): resp.set_status(400) def test_nonstr_text_in_ctor() -> None: with pytest.raises(TypeError): web.Response(text=b"data") # type: ignore[arg-type] def 
test_text_in_ctor_with_content_type() -> None: resp = web.Response(text="data", content_type="text/html") assert "data" == resp.text assert "text/html" == resp.content_type def test_text_in_ctor_with_content_type_header() -> None: resp = web.Response( text="текст", headers={"Content-Type": "text/html; charset=koi8-r"} ) assert "текст".encode("koi8-r") == resp.body assert "text/html" == resp.content_type assert "koi8-r" == resp.charset def test_text_in_ctor_with_content_type_header_multidict() -> None: headers = CIMultiDict({"Content-Type": "text/html; charset=koi8-r"}) resp = web.Response(text="текст", headers=headers) assert "текст".encode("koi8-r") == resp.body assert "text/html" == resp.content_type assert "koi8-r" == resp.charset def test_body_in_ctor_with_content_type_header_multidict() -> None: headers = CIMultiDict({"Content-Type": "text/html; charset=koi8-r"}) resp = web.Response(body="текст".encode("koi8-r"), headers=headers) assert "текст".encode("koi8-r") == resp.body assert "text/html" == resp.content_type assert "koi8-r" == resp.charset def test_text_with_empty_payload() -> None: resp = web.Response(status=200) assert resp.body is None assert resp.text is None def test_response_with_content_length_header_without_body() -> None: resp = web.Response(headers={"Content-Length": "123"}) assert resp.content_length == 123 def test_response_with_immutable_headers() -> None: resp = web.Response( text="text", headers=CIMultiDictProxy(CIMultiDict({"Header": "Value"})) ) assert resp.headers == { "Header": "Value", "Content-Type": "text/plain; charset=utf-8", } async def test_response_prepared_after_header_preparation() -> None: req = make_request("GET", "/") resp = web.StreamResponse() await resp.prepare(req) assert type(resp.headers["Server"]) is str async def _strip_server(req: web.Request, res: web.Response) -> None: assert "Server" in res.headers del res.headers["Server"] app = mock.create_autospec(web.Application, spec_set=True) app.on_response_prepare = 
aiosignal.Signal(app) app.on_response_prepare.append(_strip_server) req = make_request("GET", "/", app=app) resp = web.StreamResponse() await resp.prepare(req) assert "Server" not in resp.headers def test_weakref_creation() -> None: resp = web.Response() weakref.ref(resp) class TestJSONResponse: def test_content_type_is_application_json_by_default(self) -> None: resp = web.json_response("") assert "application/json" == resp.content_type def test_passing_text_only(self) -> None: resp = web.json_response(text=json.dumps("jaysawn")) assert resp.text == json.dumps("jaysawn") def test_data_and_text_raises_value_error(self) -> None: with pytest.raises(ValueError) as excinfo: web.json_response(data="foo", text="bar") expected_message = "only one of data, text, or body should be specified" assert expected_message == excinfo.value.args[0] def test_data_and_body_raises_value_error(self) -> None: with pytest.raises(ValueError) as excinfo: web.json_response(data="foo", body=b"bar") expected_message = "only one of data, text, or body should be specified" assert expected_message == excinfo.value.args[0] def test_text_is_json_encoded(self) -> None: resp = web.json_response({"foo": 42}) assert json.dumps({"foo": 42}) == resp.text def test_content_type_is_overrideable(self) -> None: resp = web.json_response({"foo": 42}, content_type="application/vnd.json+api") assert "application/vnd.json+api" == resp.content_type class TestJSONBytesResponse: def test_content_type_is_application_json_by_default(self) -> None: resp = web.json_bytes_response( "", dumps=lambda x: json.dumps(x).encode("utf-8") ) assert "application/json" == resp.content_type def test_passing_body_only(self) -> None: resp = web.json_bytes_response( dumps=lambda x: json.dumps(x).encode("utf-8"), body=b'"jaysawn"', ) assert resp.body == b'"jaysawn"' def test_data_and_body_raises_value_error(self) -> None: with pytest.raises(ValueError) as excinfo: web.json_bytes_response( data="foo", dumps=lambda x: 
json.dumps(x).encode("utf-8"), body=b"bar" ) expected_message = "only one of data or body should be specified" assert expected_message == excinfo.value.args[0] def test_body_is_json_encoded_bytes(self) -> None: resp = web.json_bytes_response( {"foo": 42}, dumps=lambda x: json.dumps(x).encode("utf-8") ) assert json.dumps({"foo": 42}).encode("utf-8") == resp.body def test_content_type_is_overrideable(self) -> None: resp = web.json_bytes_response( {"foo": 42}, dumps=lambda x: json.dumps(x).encode("utf-8"), content_type="application/vnd.json+api", ) assert "application/vnd.json+api" == resp.content_type def test_custom_dumps(self) -> None: resp = web.json_bytes_response( {"foo": 42}, dumps=lambda x: json.dumps(x, separators=(",", ":")).encode("utf-8"), ) assert b'{"foo":42}' == resp.body @pytest.mark.dev_mode async def test_no_warn_small_cookie( buf: bytearray, writer: AbstractStreamWriter ) -> None: resp = web.Response() resp.set_cookie("foo", "ÿ" + "8" * 4064, max_age=2600) # No warning req = make_request("GET", "/", writer=writer) await resp.prepare(req) await resp.write_eof() match = re.search(b"Set-Cookie: (.*?)\r\n", buf) assert match is not None cookie = match.group(1) assert len(cookie) == 4096 @pytest.mark.dev_mode async def test_warn_large_cookie(buf: bytearray, writer: AbstractStreamWriter) -> None: resp = web.Response() with pytest.warns( UserWarning, match="The size of is too large, it might get ignored by the client.", ): resp.set_cookie("foo", "ÿ" + "8" * 4065, max_age=2600) req = make_request("GET", "/", writer=writer) await resp.prepare(req) await resp.write_eof() match = re.search(b"Set-Cookie: (.*?)\r\n", buf) assert match is not None cookie = match.group(1) assert len(cookie) == 4097 @pytest.mark.parametrize("loose_header_type", (MultiDict, CIMultiDict, dict)) async def test_passing_cimultidict_to_web_response_not_mutated( loose_header_type: type, ) -> None: req = make_request("GET", "/") headers = loose_header_type({}) resp = 
web.Response(body=b"answer", headers=headers) await resp.prepare(req) assert resp.content_length == 6 assert not headers async def test_stream_response_sends_headers_immediately() -> None: """Test that StreamResponse sends headers immediately.""" writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) req = make_request("GET", "/", writer=writer) resp = web.StreamResponse() # StreamResponse should have _send_headers_immediately = True assert resp._send_headers_immediately is True # Prepare the response await resp.prepare(req) # Headers should be sent immediately writer.send_headers.assert_called_once() async def test_response_buffers_headers() -> None: """Test that Response buffers headers for packet coalescing.""" writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) req = make_request("GET", "/", writer=writer) resp = web.Response(body=b"hello") # Response should have _send_headers_immediately = False assert resp._send_headers_immediately is False # Prepare the response await resp.prepare(req) # Headers should NOT be sent immediately writer.send_headers.assert_not_called() # But write_headers should have been called writer.write_headers.assert_called_once() ================================================ FILE: tests/test_web_runner.py ================================================ import asyncio import platform import signal from collections.abc import Iterator from typing import Any, NoReturn, Protocol from unittest import mock import pytest from aiohttp import web from aiohttp.abc import AbstractAccessLogger from aiohttp.test_utils import get_unused_port_socket from aiohttp.web_log import AccessLogger class _RunnerMaker(Protocol): def __call__(self, handle_signals: bool = ..., **kwargs: Any) -> web.AppRunner: ... 
@pytest.fixture
def app() -> web.Application:
    """Provide a fresh, empty web.Application for each test."""
    return web.Application()


@pytest.fixture
def make_runner(
    loop: asyncio.AbstractEventLoop, app: web.Application
) -> Iterator[_RunnerMaker]:
    """Yield a factory that builds AppRunner instances around ``app``.

    Every runner handed out is remembered and cleaned up on ``loop`` when
    the fixture is torn down, so tests do not have to call cleanup().
    Signal handling is off unless a test asks for it.
    """
    asyncio.set_event_loop(loop)
    runners: list[web.AppRunner] = []

    def go(handle_signals: bool = False, **kwargs: Any) -> web.AppRunner:
        runner = web.AppRunner(app, handle_signals=handle_signals, **kwargs)
        runners.append(runner)
        return runner

    yield go
    # Teardown: release whatever each created runner set up.
    for runner in runners:
        loop.run_until_complete(runner.cleanup())


async def test_site_for_nonfrozen_app(make_runner: _RunnerMaker) -> None:
    """Creating a site on a runner that was never set up is rejected."""
    runner = make_runner()
    with pytest.raises(RuntimeError):
        web.TCPSite(runner)
    assert len(runner.sites) == 0


@pytest.mark.skipif(
    platform.system() == "Windows", reason="the test is not valid for Windows"
)
async def test_runner_setup_handle_signals(make_runner: _RunnerMaker) -> None:
    """With handle_signals=True, setup() replaces SIGTERM and cleanup() resets it."""
    # Save the original signal handler
    original_handler = signal.getsignal(signal.SIGTERM)
    try:
        # Set a known state for the signal handler to avoid flaky tests
        signal.signal(signal.SIGTERM, signal.SIG_DFL)

        runner = make_runner(handle_signals=True)
        await runner.setup()
        assert signal.getsignal(signal.SIGTERM) is not signal.SIG_DFL
        await runner.cleanup()
        assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL
    finally:
        # Restore original signal handler
        signal.signal(signal.SIGTERM, original_handler)


@pytest.mark.skipif(
    platform.system() == "Windows", reason="the test is not valid for Windows"
)
async def test_runner_setup_without_signal_handling(make_runner: _RunnerMaker) -> None:
    """With handle_signals=False, the SIGTERM handler is never touched."""
    # Save the original signal handler
    original_handler = signal.getsignal(signal.SIGTERM)
    try:
        # Set a known state for the signal handler to avoid flaky tests
        signal.signal(signal.SIGTERM, signal.SIG_DFL)

        runner = make_runner(handle_signals=False)
        await runner.setup()
        assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL
        await runner.cleanup()
        assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL
    finally:
        # Restore original signal handler
        signal.signal(signal.SIGTERM,
original_handler)


async def test_site_double_added(make_runner: _RunnerMaker) -> None:
    """Starting an already started site raises and leaves one site registered."""
    _sock = get_unused_port_socket("127.0.0.1")
    runner = make_runner()
    await runner.setup()
    site = web.SockSite(runner, _sock)
    await site.start()
    with pytest.raises(RuntimeError):
        await site.start()

    assert len(runner.sites) == 1


async def test_site_stop_not_started(make_runner: _RunnerMaker) -> None:
    """Stopping a site that was never started raises RuntimeError."""
    runner = make_runner()
    await runner.setup()
    site = web.TCPSite(runner)
    with pytest.raises(RuntimeError):
        await site.stop()

    assert len(runner.sites) == 0


async def test_custom_log_format(make_runner: _RunnerMaker) -> None:
    """access_log_format given to the runner is forwarded to the server kwargs."""
    runner = make_runner(access_log_format="abc")
    await runner.setup()
    assert runner.server is not None
    assert runner.server._kwargs["access_log_format"] == "abc"


async def test_unreg_site(make_runner: _RunnerMaker) -> None:
    """Unregistering a site that was never registered raises RuntimeError."""
    runner = make_runner()
    await runner.setup()
    site = web.TCPSite(runner)
    with pytest.raises(RuntimeError):
        runner._unreg_site(site)


async def test_app_property(make_runner: _RunnerMaker, app: web.Application) -> None:
    """runner.app is the very application object the runner was built from."""
    runner = make_runner()
    assert runner.app is app


def test_non_app() -> None:
    """AppRunner refuses an object that is not a web.Application."""
    with pytest.raises(TypeError):
        web.AppRunner(object())  # type: ignore[arg-type]


def test_app_handler_args() -> None:
    """Application handler_args are merged into the runner's kwargs."""
    app = web.Application(handler_args={"test": True})
    runner = web.AppRunner(app)
    assert runner._kwargs == {"access_log_class": AccessLogger, "test": True}


async def test_app_handler_args_failure() -> None:
    """An unknown handler argument leaves the ceil threshold at 5.

    NOTE(review): despite the name, nothing in this test expects a failure;
    setup() is required to succeed with the unknown key. Presumably 5 is the
    handler's default threshold - confirm against the request handler.
    """
    app = web.Application(handler_args={"unknown_parameter": 5})
    runner = web.AppRunner(app)
    await runner.setup()
    assert runner._server
    rh = runner._server()
    assert rh._timeout_ceil_threshold == 5
    await runner.cleanup()
    assert app


@pytest.mark.parametrize(
    ("value", "expected"),
    (
        (2, 2),
        (None, 5),
        ("2", 2),
    ),
)
async def test_app_handler_args_ceil_threshold(
    value: int | str | None, expected: int
) -> None:
    """timeout_ceil_threshold accepts an int, a numeric string, or None."""
    app = web.Application(handler_args={"timeout_ceil_threshold": value})
    runner = web.AppRunner(app)
    await runner.setup()
    assert runner._server
    rh =
runner._server() assert rh._timeout_ceil_threshold == expected await runner.cleanup() assert app async def test_app_make_handler_access_log_class_bad_type1() -> None: class Logger: pass app = web.Application() with pytest.raises(TypeError): web.AppRunner(app, access_log_class=Logger) # type: ignore[arg-type] async def test_app_make_handler_access_log_class_bad_type2() -> None: class Logger: pass app = web.Application(handler_args={"access_log_class": Logger}) with pytest.raises(TypeError): web.AppRunner(app) async def test_app_make_handler_access_log_class1() -> None: class Logger(AbstractAccessLogger): def log( self, request: web.BaseRequest, response: web.StreamResponse, time: float ) -> None: """Pass log method.""" app = web.Application() runner = web.AppRunner(app, access_log_class=Logger) assert runner._kwargs["access_log_class"] is Logger async def test_app_make_handler_access_log_class2() -> None: class Logger(AbstractAccessLogger): def log( self, request: web.BaseRequest, response: web.StreamResponse, time: float ) -> None: """Pass log method.""" app = web.Application(handler_args={"access_log_class": Logger}) runner = web.AppRunner(app) assert runner._kwargs["access_log_class"] is Logger async def test_app_make_handler_no_access_log_class() -> None: app = web.Application(handler_args={"access_log": None}) runner = web.AppRunner(app) assert runner._kwargs["access_log"] is None async def test_addresses(make_runner: _RunnerMaker, unix_sockname: str) -> None: _sock = get_unused_port_socket("127.0.0.1") runner = make_runner() await runner.setup() tcp = web.SockSite(runner, _sock) await tcp.start() unix = web.UnixSite(runner, unix_sockname) await unix.start() actual_addrs = runner.addresses expected_host, expected_post = _sock.getsockname()[:2] assert actual_addrs == [(expected_host, expected_post), unix_sockname] @pytest.mark.skipif( platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) async def 
test_named_pipe_runner_wrong_loop( app: web.Application, selector_loop: asyncio.AbstractEventLoop, pipe_name: str ) -> None: runner = web.AppRunner(app) await runner.setup() with pytest.raises(RuntimeError): web.NamedPipeSite(runner, pipe_name) @pytest.mark.skipif( platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) async def test_named_pipe_runner_proactor_loop( proactor_loop: asyncio.AbstractEventLoop, app: web.Application, pipe_name: str ) -> None: runner = web.AppRunner(app) await runner.setup() pipe = web.NamedPipeSite(runner, pipe_name) await pipe.start() await runner.cleanup() async def test_tcpsite_default_host(make_runner: _RunnerMaker) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner) assert site.name == "http://0.0.0.0:8080" m = mock.create_autospec(asyncio.AbstractEventLoop, spec_set=True, instance=True) m.create_server.return_value = mock.create_autospec(asyncio.Server, spec_set=True) with mock.patch( "asyncio.get_event_loop", autospec=True, spec_set=True, return_value=m ): await site.start() m.create_server.assert_called_once() args, kwargs = m.create_server.call_args assert args == (runner.server, None, 8080) async def test_tcpsite_empty_str_host(make_runner: _RunnerMaker) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner, host="") assert site.port == 8080 assert site.name == "http://0.0.0.0:8080" async def test_tcpsite_ephemeral_port(make_runner: _RunnerMaker) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner, port=0) assert site.port == 0 await site.start() assert site.port != 0 assert site.name.startswith("http://0.0.0.0:") await site.stop() def test_run_after_asyncio_run() -> None: called = False async def nothing() -> None: pass def spy() -> None: nonlocal called called = True async def shutdown() -> NoReturn: spy() raise web.GracefulExit() # asyncio.run() creates a new loop and closes it. 
asyncio.run(nothing()) app = web.Application() # create_task() will delay the function until app is run. app.on_startup.append(lambda a: asyncio.create_task(shutdown())) web.run_app(app) assert called, "run_app() should work after asyncio.run()." ================================================ FILE: tests/test_web_sendfile.py ================================================ import asyncio import io from pathlib import Path from stat import S_IFREG, S_IRUSR, S_IWUSR from unittest import mock from aiohttp import hdrs from aiohttp.http_writer import StreamWriter from aiohttp.test_utils import make_mocked_request from aiohttp.web_fileresponse import FileResponse MOCK_MODE = S_IFREG | S_IRUSR | S_IWUSR def test_using_gzip_if_header_present_and_file_available( loop: asyncio.AbstractEventLoop, ) -> None: request = make_mocked_request( "GET", "http://python.org/logo.png", # Header uses some uppercase to ensure case-insensitive treatment headers={hdrs.ACCEPT_ENCODING: "GZip"}, ) gz_filepath = mock.create_autospec(Path, spec_set=True) gz_filepath.lstat.return_value.st_size = 1024 gz_filepath.lstat.return_value.st_mtime_ns = 1603733507222449291 gz_filepath.lstat.return_value.st_mode = MOCK_MODE filepath = mock.create_autospec(Path, spec_set=True) filepath.name = "logo.png" filepath.with_suffix.return_value = gz_filepath file_sender = FileResponse(filepath) file_sender._path = filepath file_sender._sendfile = mock.AsyncMock(return_value=None) # type: ignore[method-assign] loop.run_until_complete(file_sender.prepare(request)) assert not filepath.open.called assert gz_filepath.open.called def test_gzip_if_header_not_present_and_file_available( loop: asyncio.AbstractEventLoop, ) -> None: request = make_mocked_request("GET", "http://python.org/logo.png", headers={}) gz_filepath = mock.create_autospec(Path, spec_set=True) gz_filepath.lstat.return_value.st_size = 1024 gz_filepath.lstat.return_value.st_mtime_ns = 1603733507222449291 gz_filepath.lstat.return_value.st_mode = MOCK_MODE 
filepath = mock.create_autospec(Path, spec_set=True) filepath.name = "logo.png" filepath.with_suffix.return_value = gz_filepath filepath.stat.return_value.st_size = 1024 filepath.stat.return_value.st_mtime_ns = 1603733507222449291 filepath.stat.return_value.st_mode = MOCK_MODE file_sender = FileResponse(filepath) file_sender._path = filepath file_sender._sendfile = mock.AsyncMock(return_value=None) # type: ignore[method-assign] loop.run_until_complete(file_sender.prepare(request)) assert filepath.open.called assert not gz_filepath.open.called def test_gzip_if_header_not_present_and_file_not_available( loop: asyncio.AbstractEventLoop, ) -> None: request = make_mocked_request("GET", "http://python.org/logo.png", headers={}) gz_filepath = mock.create_autospec(Path, spec_set=True) gz_filepath.stat.side_effect = OSError(2, "No such file or directory") filepath = mock.create_autospec(Path, spec_set=True) filepath.name = "logo.png" filepath.with_suffix.return_value = gz_filepath filepath.stat.return_value.st_size = 1024 filepath.stat.return_value.st_mtime_ns = 1603733507222449291 filepath.stat.return_value.st_mode = MOCK_MODE file_sender = FileResponse(filepath) file_sender._path = filepath file_sender._sendfile = mock.AsyncMock(return_value=None) # type: ignore[method-assign] loop.run_until_complete(file_sender.prepare(request)) assert filepath.open.called assert not gz_filepath.open.called def test_gzip_if_header_present_and_file_not_available( loop: asyncio.AbstractEventLoop, ) -> None: request = make_mocked_request( "GET", "http://python.org/logo.png", headers={hdrs.ACCEPT_ENCODING: "gzip"} ) gz_filepath = mock.create_autospec(Path, spec_set=True) gz_filepath.lstat.side_effect = OSError(2, "No such file or directory") filepath = mock.create_autospec(Path, spec_set=True) filepath.name = "logo.png" filepath.with_suffix.return_value = gz_filepath filepath.stat.return_value.st_size = 1024 filepath.stat.return_value.st_mtime_ns = 1603733507222449291 
filepath.stat.return_value.st_mode = MOCK_MODE file_sender = FileResponse(filepath) file_sender._path = filepath file_sender._sendfile = mock.AsyncMock(return_value=None) # type: ignore[method-assign] loop.run_until_complete(file_sender.prepare(request)) assert filepath.open.called assert not gz_filepath.open.called def test_status_controlled_by_user(loop: asyncio.AbstractEventLoop) -> None: request = make_mocked_request("GET", "http://python.org/logo.png", headers={}) filepath = mock.create_autospec(Path, spec_set=True) filepath.name = "logo.png" filepath.stat.return_value.st_size = 1024 filepath.stat.return_value.st_mtime_ns = 1603733507222449291 filepath.stat.return_value.st_mode = MOCK_MODE file_sender = FileResponse(filepath, status=203) file_sender._path = filepath file_sender._sendfile = mock.AsyncMock(return_value=None) # type: ignore[method-assign] loop.run_until_complete(file_sender.prepare(request)) assert file_sender._status == 203 async def test_file_response_sends_headers_immediately() -> None: """Test that FileResponse sends headers immediately (inherits from StreamResponse).""" writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) request = make_mocked_request("GET", "http://python.org/logo.png", writer=writer) filepath = mock.create_autospec(Path, spec_set=True) filepath.name = "logo.png" filepath.stat.return_value.st_size = 1024 filepath.stat.return_value.st_mtime_ns = 1603733507222449291 filepath.stat.return_value.st_mode = MOCK_MODE file_sender = FileResponse(filepath) file_sender._path = filepath file_sender._sendfile = mock.AsyncMock(return_value=None) # type: ignore[method-assign] # FileResponse inherits from StreamResponse, so should send immediately assert file_sender._send_headers_immediately is True # Prepare the response await file_sender.prepare(request) # Headers should be sent immediately writer.send_headers.assert_called_once() async def test_sendfile_fallback_respects_count_boundary() -> None: """Regression test: 
_sendfile_fallback should not read beyond the requested count. Previously the first chunk used the full chunk_size even when count was smaller, and the loop subtracted chunk_size instead of the actual bytes read. Both bugs could cause extra data to be sent when serving range requests. """ file_data = b"A" * 100 + b"B" * 50 # 150 bytes total fobj = io.BytesIO(file_data) writer = mock.AsyncMock() written = bytearray() async def capture_write(data: bytes) -> None: written.extend(data) writer.write = capture_write writer.drain = mock.AsyncMock() file_sender = FileResponse("dummy.bin") file_sender._chunk_size = 64 # smaller than count to test multi-chunk # Request only the first 100 bytes (offset=0, count=100) await file_sender._sendfile_fallback(writer, fobj, offset=0, count=100) assert bytes(written) == b"A" * 100 assert len(written) == 100 ================================================ FILE: tests/test_web_sendfile_functional.py ================================================ import asyncio import bz2 import gzip import pathlib import socket from collections.abc import Iterable, Iterator from typing import Protocol from unittest import mock import pytest from _pytest.fixtures import SubRequest import aiohttp from aiohttp import web from aiohttp.compression_utils import ZLibBackend from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.typedefs import PathLike try: import brotlicffi as brotli except ImportError: import brotli try: import ssl except ImportError: ssl = None # type: ignore[assignment] class _Sender(Protocol): def __call__( self, path: PathLike, chunk_size: int = 256 * 1024 ) -> web.FileResponse: ... HELLO_AIOHTTP = b"Hello aiohttp! :-)\n" @pytest.fixture(scope="module") def hello_txt( request: pytest.FixtureRequest, tmp_path_factory: pytest.TempPathFactory ) -> pathlib.Path: """Create a temp path with hello.txt and compressed versions. The uncompressed text file path is returned by default. 
Alternatively, an indirect parameter can be passed with an encoding to get a compressed path. """ txt = tmp_path_factory.mktemp("hello-") / "hello.txt" hello = { None: txt, "gzip": txt.with_suffix(f"{txt.suffix}.gz"), "br": txt.with_suffix(f"{txt.suffix}.br"), "bzip2": txt.with_suffix(f"{txt.suffix}.bz2"), } # Uncompressed file is not actually written to test it is not required. hello["gzip"].write_bytes(gzip.compress(HELLO_AIOHTTP)) hello["br"].write_bytes(brotli.compress(HELLO_AIOHTTP)) hello["bzip2"].write_bytes(bz2.compress(HELLO_AIOHTTP)) encoding = getattr(request, "param", None) return hello[encoding] @pytest.fixture(params=["sendfile", "no_sendfile"], ids=["sendfile", "no_sendfile"]) def sender(request: SubRequest, loop: asyncio.AbstractEventLoop) -> Iterator[_Sender]: sendfile_mock = None def maker(path: PathLike, chunk_size: int = 256 * 1024) -> web.FileResponse: ret = web.FileResponse(path, chunk_size=chunk_size) rloop = asyncio.get_running_loop() is_patched = rloop.sendfile is sendfile_mock assert is_patched if request.param == "no_sendfile" else not is_patched return ret if request.param == "no_sendfile": with mock.patch.object( loop, "sendfile", autospec=True, spec_set=True, side_effect=NotImplementedError, ) as sendfile_mock: yield maker else: yield maker @pytest.fixture def app_with_static_route(sender: _Sender) -> web.Application: filename = "data.unknown_mime_type" filepath = pathlib.Path(__file__).parent / filename async def handler(request: web.Request) -> web.FileResponse: return sender(filepath) app = web.Application() app.router.add_get("/", handler) return app async def test_static_file_ok( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) resp = await client.get("/") assert resp.status == 200 txt = await resp.text() assert "file content" == txt.rstrip() assert "application/octet-stream" == resp.headers["Content-Type"] assert 
resp.headers.get("Content-Encoding") is None resp.release() await client.close() async def test_zero_bytes_file_ok( aiohttp_client: AiohttpClient, sender: _Sender ) -> None: filepath = pathlib.Path(__file__).parent / "data.zero_bytes" async def handler(request: web.Request) -> web.FileResponse: return sender(filepath) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) # Run the request multiple times to ensure # that an untrapped exception is not hidden # because there is no read of the zero bytes for i in range(2): resp = await client.get("/") assert resp.status == 200 txt = await resp.text() assert "" == txt.rstrip() assert "application/octet-stream" == resp.headers["Content-Type"] assert resp.headers.get("Content-Encoding") is None resp.release() await client.close() async def test_zero_bytes_file_mocked_native_sendfile( aiohttp_client: AiohttpClient, ) -> None: filepath = pathlib.Path(__file__).parent / "data.zero_bytes" async def handler(request: web.Request) -> web.FileResponse: return web.FileResponse(filepath) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) # Run the request multiple times to ensure # that an untrapped exception is not hidden # because there is no read of the zero bytes for i in range(2): resp = await client.get("/") assert resp.status == 200 txt = await resp.text() assert "" == txt.rstrip() assert "application/octet-stream" == resp.headers["Content-Type"] assert resp.headers.get("Content-Encoding") is None assert resp.headers.get("Content-Length") == "0" resp.release() await client.close() async def test_static_file_ok_string_path( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) resp = await client.get("/") assert resp.status == 200 txt = await resp.text() assert "file content" == txt.rstrip() assert "application/octet-stream" == resp.headers["Content-Type"] assert 
resp.headers.get("Content-Encoding") is None resp.release() await client.close() async def test_static_file_not_exists(aiohttp_client: AiohttpClient) -> None: app = web.Application() client = await aiohttp_client(app) resp = await client.get("/fake") assert resp.status == 404 resp.release() await client.close() async def test_static_file_name_too_long(aiohttp_client: AiohttpClient) -> None: app = web.Application() client = await aiohttp_client(app) resp = await client.get("/x*500") assert resp.status == 404 resp.release() await client.close() async def test_static_file_upper_directory(aiohttp_client: AiohttpClient) -> None: app = web.Application() client = await aiohttp_client(app) resp = await client.get("/../../") assert resp.status == 404 resp.release() await client.close() async def test_static_file_with_content_type( aiohttp_client: AiohttpClient, sender: _Sender ) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.jpg" async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert resp.status == 200 body = await resp.read() with filepath.open("rb") as f: content = f.read() assert content == body assert resp.headers["Content-Type"] == "image/jpeg" assert resp.headers.get("Content-Encoding") is None resp.close() resp.release() await client.close() @pytest.mark.parametrize("hello_txt", ["gzip", "br"], indirect=True) async def test_static_file_custom_content_type( hello_txt: pathlib.Path, aiohttp_client: AiohttpClient, sender: _Sender ) -> None: """Test that custom type without encoding is returned for encoded request.""" async def handler(request: web.Request) -> web.FileResponse: resp = sender(hello_txt, chunk_size=16) resp.content_type = "application/pdf" return resp app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") 
assert resp.status == 200 assert resp.headers.get("Content-Encoding") is None assert resp.headers["Content-Type"] == "application/pdf" assert await resp.read() == hello_txt.read_bytes() resp.close() resp.release() await client.close() @pytest.mark.parametrize( ("accept_encoding", "expect_encoding"), [("gzip, deflate", "gzip"), ("gzip, deflate, br", "br")], ) async def test_static_file_custom_content_type_compress( hello_txt: pathlib.Path, aiohttp_client: AiohttpClient, sender: _Sender, accept_encoding: str, expect_encoding: str, ) -> None: """Test that custom type with encoding is returned for unencoded requests.""" async def handler(request: web.Request) -> web.FileResponse: resp = sender(hello_txt, chunk_size=16) resp.content_type = "application/pdf" return resp app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/", headers={"Accept-Encoding": accept_encoding}) assert resp.status == 200 assert resp.headers.get("Content-Encoding") == expect_encoding assert resp.headers["Content-Type"] == "application/pdf" assert await resp.read() == HELLO_AIOHTTP resp.close() resp.release() await client.close() @pytest.mark.parametrize( ("accept_encoding", "expect_encoding"), [("gzip, deflate", "gzip"), ("gzip, deflate, br", "br")], ) @pytest.mark.parametrize("forced_compression", [None, web.ContentCoding.gzip]) @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_static_file_with_encoding_and_enable_compression( hello_txt: pathlib.Path, aiohttp_client: AiohttpClient, sender: _Sender, accept_encoding: str, expect_encoding: str, forced_compression: web.ContentCoding | None, ) -> None: """Test that enable_compression does not double compress when an encoded file is also present.""" async def handler(request: web.Request) -> web.FileResponse: resp = sender(hello_txt) resp.enable_compression(forced_compression) return resp app = web.Application() app.router.add_get("/", handler) client = await 
aiohttp_client(app) resp = await client.get("/", headers={"Accept-Encoding": accept_encoding}) assert resp.status == 200 assert resp.headers.get("Content-Encoding") == expect_encoding assert resp.headers["Content-Type"] == "text/plain" assert await resp.read() == HELLO_AIOHTTP resp.close() resp.release() await client.close() @pytest.mark.parametrize( ("hello_txt", "expect_type"), [ ("gzip", "application/gzip"), ("br", "application/x-brotli"), ("bzip2", "application/x-bzip2"), ], indirect=["hello_txt"], ) async def test_static_file_with_content_encoding( hello_txt: pathlib.Path, aiohttp_client: AiohttpClient, sender: _Sender, expect_type: str, ) -> None: """Test requesting static compressed files returns the correct content type and encoding.""" async def handler(request: web.Request) -> web.FileResponse: return sender(hello_txt) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert resp.status == 200 assert resp.headers.get("Content-Encoding") is None assert resp.headers["Content-Type"] == expect_type assert await resp.read() == hello_txt.read_bytes() resp.close() resp.release() await client.close() async def test_static_file_if_modified_since( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) resp = await client.get("/") assert 200 == resp.status lastmod = resp.headers.get("Last-Modified") assert lastmod is not None resp.close() resp.release() resp = await client.get("/", headers={"If-Modified-Since": lastmod}) body = await resp.read() assert 304 == resp.status assert resp.headers.get("Content-Length") is None assert resp.headers.get("Last-Modified") == lastmod assert b"" == body resp.close() resp.release() await client.close() async def test_static_file_if_modified_since_past_date( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await 
aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" resp = await client.get("/", headers={"If-Modified-Since": lastmod}) assert 200 == resp.status resp.close() resp.release() await client.close() async def test_static_file_if_modified_since_invalid_date( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "not a valid HTTP-date" resp = await client.get("/", headers={"If-Modified-Since": lastmod}) assert 200 == resp.status resp.close() resp.release() await client.close() async def test_static_file_if_modified_since_future_date( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" resp = await client.get("/", headers={"If-Modified-Since": lastmod}) body = await resp.read() assert 304 == resp.status assert resp.headers.get("Content-Length") is None assert resp.headers.get("Last-Modified") assert b"" == body resp.close() resp.release() await client.close() @pytest.mark.parametrize("if_unmodified_since", ("", "Fri, 31 Dec 0000 23:59:59 GMT")) async def test_static_file_if_match( aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_unmodified_since: str, ) -> None: client = await aiohttp_client(app_with_static_route) resp = await client.get("/") assert 200 == resp.status original_etag = resp.headers.get("ETag") assert original_etag is not None resp.close() resp.release() headers = {"If-Match": original_etag, "If-Unmodified-Since": if_unmodified_since} resp = await client.head("/", headers=headers) body = await resp.read() assert 200 == resp.status assert resp.headers.get("ETag") assert resp.headers.get("Last-Modified") assert b"" == body resp.close() resp.release() await client.close() @pytest.mark.parametrize("if_unmodified_since", ("", "Fri, 31 Dec 0000 23:59:59 GMT")) @pytest.mark.parametrize( 
"etags,expected_status", [ (("*",), 200), (('"example-tag"', 'W/"weak-tag"'), 412), ], ) async def test_static_file_if_match_custom_tags( aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_unmodified_since: str, etags: tuple[str], expected_status: int, ) -> None: client = await aiohttp_client(app_with_static_route) if_match = ", ".join(etags) headers = {"If-Match": if_match, "If-Unmodified-Since": if_unmodified_since} resp = await client.head("/", headers=headers) body = await resp.read() assert expected_status == resp.status assert b"" == body resp.close() resp.release() await client.close() @pytest.mark.parametrize("if_modified_since", ("", "Fri, 31 Dec 9999 23:59:59 GMT")) @pytest.mark.parametrize( "additional_etags", ( (), ('"some-other-strong-etag"', 'W/"weak-tag"', "invalid-tag"), ), ) async def test_static_file_if_none_match( aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_modified_since: str, additional_etags: Iterable[str], ) -> None: client = await aiohttp_client(app_with_static_route) resp = await client.get("/") assert 200 == resp.status original_etag = resp.headers["ETag"] assert resp.headers.get("Last-Modified") is not None resp.close() resp.release() etag = ",".join((original_etag, *additional_etags)) resp = await client.get( "/", headers={"If-None-Match": etag, "If-Modified-Since": if_modified_since} ) body = await resp.read() assert 304 == resp.status assert resp.headers.get("Content-Length") is None assert resp.headers.get("ETag") == original_etag assert b"" == body resp.close() resp.release() await client.close() async def test_static_file_if_none_match_star( aiohttp_client: AiohttpClient, app_with_static_route: web.Application, ) -> None: client = await aiohttp_client(app_with_static_route) resp = await client.head("/", headers={"If-None-Match": "*"}) body = await resp.read() assert 304 == resp.status assert resp.headers.get("Content-Length") is None assert resp.headers.get("ETag") assert 
resp.headers.get("Last-Modified") assert b"" == body resp.close() resp.release() await client.close() @pytest.mark.parametrize("if_modified_since", ("", "Fri, 31 Dec 9999 23:59:59 GMT")) async def test_static_file_if_none_match_weak( aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_modified_since: str, ) -> None: client = await aiohttp_client(app_with_static_route) resp = await client.get("/") assert 200 == resp.status original_etag = resp.headers["ETag"] assert resp.headers.get("Last-Modified") is not None resp.close() resp.release() weak_etag = f"W/{original_etag}" resp = await client.get( "/", headers={"If-None-Match": weak_etag, "If-Modified-Since": if_modified_since}, ) body = await resp.read() assert 304 == resp.status assert resp.headers.get("Content-Length") is None assert resp.headers.get("ETag") == original_etag assert b"" == body resp.close() resp.release() await client.close() @pytest.mark.skipif(ssl is None, reason="ssl not supported") async def test_static_file_ssl( aiohttp_server: AiohttpServer, ssl_ctx: ssl.SSLContext, aiohttp_client: AiohttpClient, client_ssl_ctx: ssl.SSLContext, ) -> None: dirname = pathlib.Path(__file__).parent filename = "data.unknown_mime_type" app = web.Application() app.router.add_static("/static", dirname) server = await aiohttp_server(app, ssl=ssl_ctx) conn = aiohttp.TCPConnector(ssl=client_ssl_ctx) client = await aiohttp_client(server, connector=conn) resp = await client.get("/static/" + filename) assert 200 == resp.status txt = await resp.text() assert "file content" == txt.rstrip() ct = resp.headers["CONTENT-TYPE"] assert "application/octet-stream" == ct assert resp.headers.get("CONTENT-ENCODING") is None resp.release() await client.close() await conn.close() async def test_static_file_directory_traversal_attack( aiohttp_client: AiohttpClient, ) -> None: dirname = pathlib.Path(__file__).parent relpath = "../README.rst" full_path = dirname / relpath assert full_path.is_file() app = 
web.Application() app.router.add_static("/static", dirname) client = await aiohttp_client(app) resp = await client.get("/static/" + relpath) assert 404 == resp.status resp.release() url_relpath2 = "/static/dir/../" + relpath resp = await client.get(url_relpath2) assert 404 == resp.status resp.release() url_abspath = "/static/" + str(full_path.resolve()) resp = await client.get(url_abspath) assert resp.status == 404 resp.release() await client.close() @pytest.mark.skip_blockbuster async def test_static_file_huge( aiohttp_client: AiohttpClient, tmp_path: pathlib.Path ) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 20MB file with file_path.open("wb") as f: for i in range(1024 * 20): f.write((chr(i % 64 + 0x20) * 1024).encode()) file_st = file_path.stat() app = web.Application() app.router.add_static("/static", str(tmp_path)) client = await aiohttp_client(app) resp = await client.get("/static/" + file_path.name) assert 200 == resp.status ct = resp.headers["CONTENT-TYPE"] assert "application/octet-stream" == ct assert resp.headers.get("CONTENT-ENCODING") is None assert int(resp.headers["CONTENT-LENGTH"]) == file_st.st_size f2 = file_path.open("rb") off = 0 cnt = 0 while off < file_st.st_size: chunk = await resp.content.readany() expected = f2.read(len(chunk)) assert chunk == expected off += len(chunk) cnt += 1 f2.close() resp.release() await client.close() async def test_static_file_range( aiohttp_client: AiohttpClient, sender: _Sender ) -> None: filepath = pathlib.Path(__file__).parent / "sample.txt" filesize = filepath.stat().st_size async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) with filepath.open("rb") as f: content = f.read() # Ensure the whole file requested in parts is correct responses = await asyncio.gather( client.get("/", headers={"Range": "bytes=0-999"}), client.get("/", headers={"Range": 
"bytes=1000-1999"}), client.get("/", headers={"Range": "bytes=2000-"}), ) assert len(responses) == 3 assert responses[0].status == 206, "failed 'bytes=0-999': %s" % responses[0].reason assert ( responses[0].headers["Content-Range"] == f"bytes 0-999/{filesize}" ), "failed: Content-Range Error" assert responses[1].status == 206, ( "failed 'bytes=1000-1999': %s" % responses[1].reason ) assert ( responses[1].headers["Content-Range"] == f"bytes 1000-1999/{filesize}" ), "failed: Content-Range Error" assert responses[2].status == 206, "failed 'bytes=2000-': %s" % responses[2].reason assert ( responses[2].headers["Content-Range"] == f"bytes 2000-{filesize - 1}/{filesize}" ), "failed: Content-Range Error" body = await asyncio.gather( *(resp.read() for resp in responses), ) assert len(body[0]) == 1000, "failed 'bytes=0-999', received %d bytes" % len( body[0] ) assert len(body[1]) == 1000, "failed 'bytes=1000-1999', received %d bytes" % len( body[1] ) responses[0].close() responses[1].close() responses[2].close() for resp in responses: resp.release() assert content == b"".join(body) await client.close() async def test_static_file_range_end_bigger_than_size( aiohttp_client: AiohttpClient, sender: _Sender ) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) with filepath.open("rb") as f: content = f.read() # Ensure the whole file requested in parts is correct response = await client.get("/", headers={"Range": "bytes=54000-55000"}) assert response.status == 206, ( "failed 'bytes=54000-55000': %s" % response.reason ) assert ( response.headers["Content-Range"] == "bytes 54000-54996/54997" ), "failed: Content-Range Error" body = await response.read() assert len(body) == 997, "failed 'bytes=54000-55000', received %d bytes" % len( body ) assert content[54000:] == body 
response.release() await client.close() async def test_static_file_range_beyond_eof( aiohttp_client: AiohttpClient, sender: _Sender ) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) # Ensure the whole file requested in parts is correct response = await client.get("/", headers={"Range": "bytes=1000000-1200000"}) assert response.status == 416, ( "failed 'bytes=1000000-1200000': %s" % response.reason ) response.release() await client.close() async def test_static_file_range_tail( aiohttp_client: AiohttpClient, sender: _Sender ) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) with filepath.open("rb") as f: content = f.read() # Ensure the tail of the file is correct resp = await client.get("/", headers={"Range": "bytes=-500"}) assert resp.status == 206, resp.reason assert ( resp.headers["Content-Range"] == "bytes 54497-54996/54997" ), "failed: Content-Range Error" body4 = await resp.read() resp.close() resp.release() assert content[-500:] == body4 # Ensure out-of-range tails could be handled resp2 = await client.get("/", headers={"Range": "bytes=-99999999999999"}) assert resp2.status == 206, resp.reason assert ( resp2.headers["Content-Range"] == "bytes 0-54996/54997" ), "failed: Content-Range Error" resp2.release() await client.close() async def test_static_file_invalid_range( aiohttp_client: AiohttpClient, sender: _Sender ) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() app.router.add_get("/", handler) client = await 
aiohttp_client(app) # range must be in bytes resp = await client.get("/", headers={"Range": "blocks=0-10"}) assert resp.status == 416, "Range must be in bytes" resp.close() resp.release() # start > end resp = await client.get("/", headers={"Range": "bytes=100-0"}) assert resp.status == 416, "Range start can't be greater than end" resp.close() resp.release() # start > end resp = await client.get("/", headers={"Range": "bytes=10-9"}) assert resp.status == 416, "Range start can't be greater than end" resp.close() resp.release() # non-number range resp = await client.get("/", headers={"Range": "bytes=a-f"}) assert resp.status == 416, "Range must be integers" resp.close() resp.release() # double dash range resp = await client.get("/", headers={"Range": "bytes=0--10"}) assert resp.status == 416, "double dash in range" resp.close() resp.release() # no range resp = await client.get("/", headers={"Range": "bytes=-"}) assert resp.status == 416, "no range given" resp.close() resp.release() await client.close() async def test_static_file_if_unmodified_since_past_with_range( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" resp = await client.get( "/", headers={"If-Unmodified-Since": lastmod, "Range": "bytes=2-"} ) assert 412 == resp.status resp.close() resp.release() await client.close() async def test_static_file_if_unmodified_since_future_with_range( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" resp = await client.get( "/", headers={"If-Unmodified-Since": lastmod, "Range": "bytes=2-"} ) assert 206 == resp.status assert resp.headers["Content-Range"] == "bytes 2-12/13" assert resp.headers["Content-Length"] == "11" resp.close() resp.release() await client.close() async def test_static_file_if_range_past_with_range( 
aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" resp = await client.get("/", headers={"If-Range": lastmod, "Range": "bytes=2-"}) assert 200 == resp.status assert resp.headers["Content-Length"] == "13" resp.close() resp.release() await client.close() async def test_static_file_if_range_future_with_range( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" resp = await client.get("/", headers={"If-Range": lastmod, "Range": "bytes=2-"}) assert 206 == resp.status assert resp.headers["Content-Range"] == "bytes 2-12/13" assert resp.headers["Content-Length"] == "11" resp.close() resp.release() await client.close() async def test_static_file_if_unmodified_since_past_without_range( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" resp = await client.get("/", headers={"If-Unmodified-Since": lastmod}) assert 412 == resp.status resp.close() resp.release() await client.close() async def test_static_file_if_unmodified_since_future_without_range( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" resp = await client.get("/", headers={"If-Unmodified-Since": lastmod}) assert 200 == resp.status assert resp.headers["Content-Length"] == "13" resp.close() resp.release() await client.close() async def test_static_file_if_range_past_without_range( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" resp = await client.get("/", headers={"If-Range": lastmod}) assert 200 == 
resp.status assert resp.headers["Content-Length"] == "13" resp.close() resp.release() await client.close() async def test_static_file_if_range_future_without_range( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" resp = await client.get("/", headers={"If-Range": lastmod}) assert 200 == resp.status assert resp.headers["Content-Length"] == "13" resp.close() resp.release() await client.close() async def test_static_file_if_unmodified_since_invalid_date( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "not a valid HTTP-date" resp = await client.get("/", headers={"If-Unmodified-Since": lastmod}) assert 200 == resp.status resp.close() resp.release() await client.close() async def test_static_file_if_range_invalid_date( aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "not a valid HTTP-date" resp = await client.get("/", headers={"If-Range": lastmod}) assert 200 == resp.status resp.close() resp.release() await client.close() @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_static_file_compression( aiohttp_client: AiohttpClient, sender: _Sender, ) -> None: filepath = pathlib.Path(__file__).parent / "data.unknown_mime_type" async def handler(request: web.Request) -> web.FileResponse: ret = sender(filepath) ret.enable_compression() return ret app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, auto_decompress=False) resp = await client.get("/") assert resp.status == 200 zcomp = ZLibBackend.compressobj(wbits=ZLibBackend.MAX_WBITS) expected_body = zcomp.compress(b"file content\n") + zcomp.flush() assert expected_body == await resp.read() assert "application/octet-stream" == resp.headers["Content-Type"] 
assert resp.headers.get("Content-Encoding") == "deflate" resp.release() await client.close() @pytest.mark.skip_blockbuster async def test_static_file_huge_cancel( aiohttp_client: AiohttpClient, tmp_path: pathlib.Path ) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 20MB file with file_path.open("wb") as f: for i in range(1024 * 20): f.write((chr(i % 64 + 0x20) * 1024).encode()) task = None async def handler(request: web.Request) -> web.FileResponse: nonlocal task task = request.task # reduce send buffer size tr = request.transport assert tr is not None sock = tr.get_extra_info("socket") sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) ret = web.FileResponse(file_path) return ret app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert resp.status == 200 assert task is not None task.cancel() await asyncio.sleep(0) data = b"" while True: try: data += await resp.content.read(1024) except aiohttp.ClientPayloadError: break assert len(data) < 1024 * 1024 * 20 resp.release() await client.close() async def test_static_file_huge_error( aiohttp_client: AiohttpClient, tmp_path: pathlib.Path ) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 20MB file with file_path.open("wb") as f: f.seek(20 * 1024 * 1024) f.write(b"1") async def handler(request: web.Request) -> web.FileResponse: # reduce send buffer size tr = request.transport assert tr is not None sock = tr.get_extra_info("socket") sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) ret = web.FileResponse(file_path) return ret app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) resp = await client.get("/") assert resp.status == 200 # raise an exception on server side resp.close() resp.release() await client.close() ================================================ FILE: tests/test_web_server.py ================================================ import asyncio
import socket from contextlib import suppress from typing import NoReturn from unittest import mock import pytest from aiohttp import client, web from aiohttp.http_exceptions import BadHttpMethod, BadStatusLine from aiohttp.pytest_plugin import AiohttpClient, AiohttpRawServer async def test_simple_server( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: async def handler(request: web.BaseRequest) -> web.Response: return web.Response(text=str(request.rel_url)) server = await aiohttp_raw_server(handler) cli = await aiohttp_client(server) resp = await cli.get("/path/to") assert resp.status == 200 txt = await resp.text() assert txt == "/path/to" async def test_unsupported_upgrade( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: # don't fail if a client probes for an unsupported protocol upgrade # https://github.com/aio-libs/aiohttp/issues/6446#issuecomment-999032039 async def handler(request: web.BaseRequest) -> web.Response: return web.Response(body=await request.read()) upgrade_headers = {"Connection": "Upgrade", "Upgrade": "unsupported_proto"} server = await aiohttp_raw_server(handler) cli = await aiohttp_client(server) test_data = b"Test" resp = await cli.post("/path/to", data=test_data, headers=upgrade_headers) assert resp.status == 200 data = await resp.read() assert data == test_data async def test_raw_server_not_http_exception( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop, ) -> None: # disable debug mode not to print traceback loop.set_debug(False) exc = RuntimeError("custom runtime error") async def handler(request: web.BaseRequest) -> NoReturn: raise exc logger = mock.Mock() server = await aiohttp_raw_server(handler, logger=logger) cli = await aiohttp_client(server) resp = await cli.get("/path/to") assert resp.status == 500 assert resp.headers["Content-Type"].startswith("text/plain") txt = await resp.text() assert txt.startswith("500 Internal 
Server Error") assert "Traceback" not in txt logger.exception.assert_called_with( "Error handling request from %s", cli.host, exc_info=exc ) async def test_raw_server_logs_invalid_method_with_loop_debug( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop, ) -> None: exc = BadHttpMethod(b"\x16\x03\x03\x01F\x01".decode(), "error") async def handler(request: web.BaseRequest) -> NoReturn: raise exc loop = asyncio.get_event_loop() loop.set_debug(True) logger = mock.Mock() server = await aiohttp_raw_server(handler, logger=logger) cli = await aiohttp_client(server) resp = await cli.get("/path/to") assert resp.status == 500 assert resp.headers["Content-Type"].startswith("text/plain") txt = await resp.text() assert "Traceback (most recent call last):\n" in txt # BadHttpMethod should be logged as debug # on the first request since the client may # be probing for TLS/SSL support which is # expected to fail logger.debug.assert_called_with( "Error handling request from %s", cli.host, exc_info=exc ) logger.debug.reset_mock() # Now make another connection to the server # to make sure that the exception is logged # at debug on a second fresh connection cli2 = await aiohttp_client(server) resp = await cli2.get("/path/to") assert resp.status == 500 assert resp.headers["Content-Type"].startswith("text/plain") # BadHttpMethod should be logged as debug # on the first request since the client may # be probing for TLS/SSL support which is # expected to fail logger.debug.assert_called_with( "Error handling request from %s", cli.host, exc_info=exc ) async def test_raw_server_logs_invalid_method_without_loop_debug( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop, ) -> None: exc = BadHttpMethod(b"\x16\x03\x03\x01F\x01".decode(), "error") async def handler(request: web.BaseRequest) -> NoReturn: raise exc loop = asyncio.get_event_loop() loop.set_debug(False) logger = mock.Mock() server = 
await aiohttp_raw_server(handler, logger=logger) cli = await aiohttp_client(server) resp = await cli.get("/path/to") assert resp.status == 500 assert resp.headers["Content-Type"].startswith("text/plain") txt = await resp.text() assert "Traceback (most recent call last):\n" not in txt # BadHttpMethod should be logged as debug # on the first request since the client may # be probing for TLS/SSL support which is # expected to fail logger.debug.assert_called_with( "Error handling request from %s", cli.host, exc_info=exc ) async def test_raw_server_logs_invalid_method_second_request( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop, ) -> None: exc = BadHttpMethod(b"\x16\x03\x03\x01F\x01".decode(), "error") request_count = 0 async def handler(request: web.BaseRequest) -> web.Response: nonlocal request_count request_count += 1 if request_count == 2: raise exc return web.Response() loop = asyncio.get_event_loop() loop.set_debug(False) logger = mock.Mock() server = await aiohttp_raw_server(handler, logger=logger) cli = await aiohttp_client(server) resp = await cli.get("/path/to") assert resp.status == 200 resp = await cli.get("/path/to") assert resp.status == 500 assert resp.headers["Content-Type"].startswith("text/plain") # BadHttpMethod should be logged as an exception # if its not the first request since we know # that the client already was speaking HTTP logger.exception.assert_called_with( "Error handling request from %s", cli.host, exc_info=exc ) async def test_raw_server_logs_bad_status_line_as_exception( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop, ) -> None: exc = BadStatusLine(b"\x16\x03\x03\x01F\x01".decode(), "error") async def handler(request: web.BaseRequest) -> NoReturn: raise exc loop = asyncio.get_event_loop() loop.set_debug(False) logger = mock.Mock() server = await aiohttp_raw_server(handler, logger=logger) cli = await aiohttp_client(server) resp = 
await cli.get("/path/to") assert resp.status == 500 assert resp.headers["Content-Type"].startswith("text/plain") txt = await resp.text() assert "Traceback (most recent call last):\n" not in txt logger.exception.assert_called_with( "Error handling request from %s", cli.host, exc_info=exc ) async def test_raw_server_handler_timeout( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: loop = asyncio.get_event_loop() loop.set_debug(True) exc = asyncio.TimeoutError("error") async def handler(request: web.BaseRequest) -> NoReturn: raise exc logger = mock.Mock() server = await aiohttp_raw_server(handler, logger=logger) cli = await aiohttp_client(server) resp = await cli.get("/path/to") assert resp.status == 504 await resp.text() logger.debug.assert_called_with("Request handler timed out.", exc_info=exc) async def test_raw_server_do_not_swallow_exceptions( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: async def handler(request: web.BaseRequest) -> NoReturn: raise asyncio.CancelledError() loop = asyncio.get_event_loop() loop.set_debug(True) logger = mock.Mock() server = await aiohttp_raw_server(handler, logger=logger) cli = await aiohttp_client(server) with pytest.raises(client.ServerDisconnectedError): await cli.get("/path/to") logger.debug.assert_called_with("Ignored premature client disconnection") async def test_raw_server_does_not_swallow_base_exceptions( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: class UnexpectedException(BaseException): """Dummy base exception.""" async def handler(request: web.BaseRequest) -> NoReturn: raise UnexpectedException() loop = asyncio.get_event_loop() loop.set_debug(True) server = await aiohttp_raw_server(handler) cli = await aiohttp_client(server) with pytest.raises(client.ServerDisconnectedError): await cli.get("/path/to", timeout=client.ClientTimeout(10)) async def test_raw_server_cancelled_in_write_eof( aiohttp_raw_server: AiohttpRawServer, 
aiohttp_client: AiohttpClient ) -> None: class MyResponse(web.Response): async def write_eof(self, data: bytes = b"") -> NoReturn: raise asyncio.CancelledError("error") async def handler(request: web.BaseRequest) -> MyResponse: resp = MyResponse(text=str(request.rel_url)) return resp loop = asyncio.get_event_loop() loop.set_debug(True) logger = mock.Mock() server = await aiohttp_raw_server(handler, logger=logger) cli = await aiohttp_client(server) with pytest.raises(client.ServerDisconnectedError): await cli.get("/path/to") logger.debug.assert_called_with("Ignored premature client disconnection") async def test_raw_server_not_http_exception_debug( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: exc = RuntimeError("custom runtime error") async def handler(request: web.BaseRequest) -> NoReturn: raise exc loop = asyncio.get_event_loop() loop.set_debug(True) logger = mock.Mock() server = await aiohttp_raw_server(handler, logger=logger) cli = await aiohttp_client(server) resp = await cli.get("/path/to") assert resp.status == 500 assert resp.headers["Content-Type"].startswith("text/plain") txt = await resp.text() assert "Traceback (most recent call last):\n" in txt logger.exception.assert_called_with( "Error handling request from %s", cli.host, exc_info=exc ) async def test_raw_server_html_exception( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop, ) -> None: # disable debug mode not to print traceback loop.set_debug(False) exc = RuntimeError("custom runtime error") async def handler(request: web.BaseRequest) -> NoReturn: raise exc logger = mock.Mock() server = await aiohttp_raw_server(handler, logger=logger) cli = await aiohttp_client(server) resp = await cli.get("/path/to", headers={"Accept": "text/html"}) assert resp.status == 500 assert resp.headers["Content-Type"].startswith("text/html") txt = await resp.text() assert txt == ( "500 Internal Server Error\n" "

    500 Internal Server Error

    \n" "Server got itself in trouble\n" "\n" ) logger.exception.assert_called_with( "Error handling request from %s", cli.host, exc_info=exc ) async def test_raw_server_html_exception_debug( aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: exc = RuntimeError("custom runtime error") async def handler(request: web.BaseRequest) -> NoReturn: raise exc loop = asyncio.get_event_loop() loop.set_debug(True) logger = mock.Mock() server = await aiohttp_raw_server(handler, logger=logger) cli = await aiohttp_client(server) resp = await cli.get("/path/to", headers={"Accept": "text/html"}) assert resp.status == 500 assert resp.headers["Content-Type"].startswith("text/html") txt = await resp.text() assert txt.startswith( "500 Internal Server Error\n" "

    500 Internal Server Error

    \n" "

    Traceback:

    \n" "
    Traceback (most recent call last):\n"
        )
    
        logger.exception.assert_called_with(
            "Error handling request from %s", cli.host, exc_info=exc
        )
    
    
    async def test_handler_cancellation(unused_port_socket: socket.socket) -> None:
        """Check that ``handler_cancellation=True`` cancels a running handler.

        The handler sleeps far longer than the client's total timeout.  Once the
        client request has timed out, the handler is expected to observe
        ``asyncio.CancelledError`` within one second.
        """
        cancelled = asyncio.Event()
        port = unused_port_socket.getsockname()[1]

        async def on_request(request: web.Request) -> web.Response:
            try:
                await asyncio.sleep(10)
            except asyncio.CancelledError:
                # Record that cancellation reached the handler, then propagate it.
                cancelled.set()
                raise
            assert False

        application = web.Application()
        application.router.add_route("GET", "/", on_request)

        app_runner = web.AppRunner(application, handler_cancellation=True)
        await app_runner.setup()
        sock_site = web.SockSite(app_runner, sock=unused_port_socket)
        await sock_site.start()

        assert app_runner.server is not None
        try:
            assert app_runner.server.handler_cancellation, "Flag was not propagated"

            short_timeout = client.ClientTimeout(total=0.15)
            async with client.ClientSession(timeout=short_timeout) as session:
                with pytest.raises(asyncio.TimeoutError):
                    await session.get(f"http://127.0.0.1:{port}/")

            # Wait at most one second for the event; the assertion below
            # reports the failure if it never fires.
            with suppress(asyncio.TimeoutError):
                await asyncio.wait_for(cancelled.wait(), timeout=1)
            assert cancelled.is_set(), "Request handler hasn't been cancelled"
        finally:
            await asyncio.gather(app_runner.shutdown(), sock_site.stop())
    
    
    async def test_no_handler_cancellation(unused_port_socket: socket.socket) -> None:
        """Check that a handler runs to completion with a default ``AppRunner``.

        The client request times out while the handler is still waiting on an
        event.  The event is set afterwards, and the handler is expected to finish
        and signal completion within one second.
        """
        release_handler = asyncio.Event()
        handler_finished = asyncio.Event()
        port = unused_port_socket.getsockname()[1]
        started = False

        async def on_request(request: web.Request) -> web.Response:
            nonlocal started
            started = True
            # Block until the test body releases the handler (or 5s elapse).
            await asyncio.wait_for(release_handler.wait(), timeout=5)
            handler_finished.set()
            return web.Response()

        application = web.Application()
        application.router.add_route("GET", "/", on_request)

        app_runner = web.AppRunner(application)
        await app_runner.setup()
        sock_site = web.SockSite(app_runner, sock=unused_port_socket)
        await sock_site.start()

        try:
            short_timeout = client.ClientTimeout(total=0.2)
            async with client.ClientSession(timeout=short_timeout) as session:
                with pytest.raises(asyncio.TimeoutError):
                    await session.get(f"http://127.0.0.1:{port}/")
            # The client has already given up; only now let the handler proceed.
            await asyncio.sleep(0.1)
            release_handler.set()

            # Wait at most one second for completion; the assertions below
            # report the failure if the handler never finished.
            with suppress(asyncio.TimeoutError):
                await asyncio.wait_for(handler_finished.wait(), timeout=1)
            assert started
            assert handler_finished.is_set()
        finally:
            await asyncio.gather(app_runner.shutdown(), sock_site.stop())
    
    
    ================================================
    FILE: tests/test_web_urldispatcher.py
    ================================================
    import asyncio
    import functools
    import os
    import pathlib
    import socket
    import sys
    from collections.abc import Generator
    from stat import S_IFIFO, S_IMODE
    from typing import Any, NoReturn
    
    import pytest
    import yarl
    
    from aiohttp import web
    from aiohttp.pytest_plugin import AiohttpClient
    from aiohttp.web_urldispatcher import Resource, SystemRoute
    
    
    @pytest.mark.parametrize(
        "show_index,status,prefix,request_path,data",
        [
            pytest.param(False, 403, "/", "/", None, id="index_forbidden"),
            pytest.param(
                True,
                200,
                "/",
                "/",
                b"\n\nIndex of /.\n\n\n

    Index of" b' /.

    \n
    \n\n", ), pytest.param( True, 200, "/static", "/static", b"\n\nIndex of /.\n\n\n

    Index of" b' /.

    \n\n\n', id="index_static", ), pytest.param( True, 200, "/static", "/static/my_dir", b"\n\nIndex of /my_dir\n\n\n

    " b'Index of /my_dir

    \n\n\n", id="index_subdir", ), pytest.param( True, 200, "/static", "/static/", b"\n\nIndex of /.\n\n\n

    Index of" b' /.

    \n\n\n', id="index_static_trailing_slash", ), pytest.param( True, 200, "/static", "/static/my_dir/", b"\n\nIndex of /my_dir\n\n\n

    " b'Index of /my_dir

    \n\n\n", id="index_subdir_trailing_slash", ), ], ) async def test_access_root_of_static_handler( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient, show_index: bool, status: int, prefix: str, request_path: str, data: bytes | None, ) -> None: # Tests the operation of static file server. # Try to access the root of static file server, and make # sure that correct HTTP statuses are returned depending if we directory # index should be shown or not. my_file = tmp_path / "my_file" my_dir = tmp_path / "my_dir" my_dir.mkdir() my_file_in_dir = my_dir / "my_file_in_dir" with my_file.open("w") as fw: fw.write("hello") with my_file_in_dir.open("w") as fw: fw.write("world") app = web.Application() # Register global static route: app.router.add_static(prefix, str(tmp_path), show_index=show_index) client = await aiohttp_client(app) # Request the root of the static directory. async with await client.get(request_path) as r: assert r.status == status if data: assert r.headers["Content-Type"] == "text/html; charset=utf-8" read_ = await r.read() assert read_ == data @pytest.mark.internal # Dependent on filesystem @pytest.mark.skipif( not sys.platform.startswith("linux"), reason="Invalid filenames on some filesystems (like Windows)", ) @pytest.mark.parametrize( "show_index,status,prefix,request_path,data", [ pytest.param(False, 403, "/", "/", None, id="index_forbidden"), pytest.param( True, 200, "/", "/", b"\n\nIndex of /.\n\n\n

    Index of" b' /.

    \n\n\n", ), pytest.param( True, 200, "/static", "/static", b"\n\nIndex of /.\n\n\n

    Index of" b' /.

    \n\n\n", id="index_static", ), pytest.param( True, 200, "/static", "/static/.dir", b"\n\nIndex of /<img src=0 onerror=alert(1)>.dir</t" b"itle>\n</head>\n<body>\n<h1>Index of /<img src=0 onerror=alert(1)>.di" b'r</h1>\n<ul>\n<li><a href="/static/%3Cimg%20src=0%20onerror=alert(1)%3E.di' b'r/my_file_in_dir">my_file_in_dir</a></li>\n</ul>\n</body>\n</html>', id="index_subdir", ), ], ) async def test_access_root_of_static_handler_xss( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient, show_index: bool, status: int, prefix: str, request_path: str, data: bytes | None, ) -> None: # Tests the operation of static file server. # Try to access the root of static file server, and make # sure that correct HTTP statuses are returned depending if we directory # index should be shown or not. # Ensure that html in file names is escaped. # Ensure that links are url quoted. my_file = tmp_path / "<img src=0 onerror=alert(1)>.txt" my_dir = tmp_path / "<img src=0 onerror=alert(1)>.dir" my_dir.mkdir() my_file_in_dir = my_dir / "my_file_in_dir" with my_file.open("w") as fw: fw.write("hello") with my_file_in_dir.open("w") as fw: fw.write("world") app = web.Application() # Register global static route: app.router.add_static(prefix, str(tmp_path), show_index=show_index) client = await aiohttp_client(app) # Request the root of the static directory. 
async with await client.get(request_path) as r: assert r.status == status if data: assert r.headers["Content-Type"] == "text/html; charset=utf-8" read_ = await r.read() assert read_ == data async def test_follow_symlink( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient ) -> None: # Tests the access to a symlink, in static folder data = "hello world" my_dir_path = tmp_path / "my_dir" my_dir_path.mkdir() my_file_path = my_dir_path / "my_file_in_dir" with my_file_path.open("w") as fw: fw.write(data) my_symlink_path = tmp_path / "my_symlink" pathlib.Path(str(my_symlink_path)).symlink_to(str(my_dir_path), True) app = web.Application() # Register global static route: app.router.add_static("/", str(tmp_path), follow_symlinks=True) client = await aiohttp_client(app) # Request the root of the static directory. r = await client.get("/my_symlink/my_file_in_dir") assert r.status == 200 assert (await r.text()) == data async def test_follow_symlink_directory_traversal( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient ) -> None: # Tests that follow_symlinks does not allow directory transversal data = "private" private_file = tmp_path / "private_file" private_file.write_text(data) safe_path = tmp_path / "safe_dir" safe_path.mkdir() app = web.Application() # Register global static route: app.router.add_static("/", str(safe_path), follow_symlinks=True) client = await aiohttp_client(app) await client.start_server() # We need to use a raw socket to test this, as the client will normalize # the path before sending it to the server. 
reader, writer = await asyncio.open_connection(client.host, client.port) writer.write(b"GET /../private_file HTTP/1.1\r\n\r\n") response = await reader.readuntil(b"\r\n\r\n") assert b"404 Not Found" in response writer.close() await writer.wait_closed() await client.close() async def test_follow_symlink_directory_traversal_after_normalization( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient ) -> None: # Tests that follow_symlinks does not allow directory transversal # after normalization # # Directory structure # |-- secret_dir # | |-- private_file (should never be accessible) # | |-- symlink_target_dir # | |-- symlink_target_file (should be accessible via the my_symlink symlink) # | |-- sandbox_dir # | |-- my_symlink -> symlink_target_dir # secret_path = tmp_path / "secret_dir" secret_path.mkdir() # This file is below the symlink target and should not be reachable private_file = secret_path / "private_file" private_file.write_text("private") symlink_target_path = secret_path / "symlink_target_dir" symlink_target_path.mkdir() sandbox_path = symlink_target_path / "sandbox_dir" sandbox_path.mkdir() # This file should be reachable via the symlink symlink_target_file = symlink_target_path / "symlink_target_file" symlink_target_file.write_text("readable") my_symlink_path = sandbox_path / "my_symlink" pathlib.Path(str(my_symlink_path)).symlink_to(str(symlink_target_path), True) app = web.Application() # Register global static route: app.router.add_static("/", str(sandbox_path), follow_symlinks=True) client = await aiohttp_client(app) await client.start_server() # We need to use a raw socket to test this, as the client will normalize # the path before sending it to the server. 
reader, writer = await asyncio.open_connection(client.host, client.port) writer.write(b"GET /my_symlink/../private_file HTTP/1.1\r\n\r\n") response = await reader.readuntil(b"\r\n\r\n") assert b"404 Not Found" in response writer.close() await writer.wait_closed() reader, writer = await asyncio.open_connection(client.host, client.port) writer.write(b"GET /my_symlink/symlink_target_file HTTP/1.1\r\n\r\n") response = await reader.readuntil(b"\r\n\r\n") assert b"200 OK" in response response = await reader.readuntil(b"readable") assert response == b"readable" writer.close() await writer.wait_closed() await client.close() @pytest.mark.parametrize( "dir_name,filename,data", [ ("", "test file.txt", "test text"), ("test dir name", "test dir file .txt", "test text file folder"), ], ) async def test_access_to_the_file_with_spaces( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient, dir_name: str, filename: str, data: str, ) -> None: # Checks operation of static files with spaces my_dir_path = tmp_path / dir_name if my_dir_path != tmp_path: my_dir_path.mkdir() my_file_path = my_dir_path / filename with my_file_path.open("w") as fw: fw.write(data) app = web.Application() url = "/" + str(pathlib.Path(dir_name, filename)) app.router.add_static("/", str(tmp_path)) client = await aiohttp_client(app) r = await client.get(url) assert r.status == 200 assert (await r.text()) == data async def test_access_non_existing_resource( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient ) -> None: # Tests accessing non-existing resource # Try to access a non-exiting resource and make sure that 404 HTTP status # returned. app = web.Application() # Register global static route: app.router.add_static("/", str(tmp_path), show_index=True) client = await aiohttp_client(app) # Request the root of the static directory. 
async with client.get("/non_existing_resource") as r: assert r.status == 404 @pytest.mark.parametrize( "registered_path,request_url", [ ("/a:b", "/a:b"), ("/a@b", "/a@b"), ("/a:b", "/a%3Ab"), ], ) async def test_url_escaping( aiohttp_client: AiohttpClient, registered_path: str, request_url: str ) -> None: # Tests accessing a resource with app = web.Application() async def handler(request: web.Request) -> web.Response: return web.Response() app.router.add_get(registered_path, handler) client = await aiohttp_client(app) async with client.get(request_url) as r: assert r.status == 200 async def test_handler_metadata_persistence() -> None: # Tests accessing metadata of a handler after registering it on the app # router. app = web.Application() async def async_handler(request: web.Request) -> web.Response: """Doc""" assert False app.router.add_get("/async", async_handler) for resource in app.router.resources(): for route in resource: assert route.handler.__doc__ == "Doc" @pytest.mark.skipif( sys.platform.startswith("win32"), reason="Cannot remove read access on Windows" ) @pytest.mark.parametrize("file_request", ["", "my_file.txt"]) async def test_static_directory_without_read_permission( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient, file_request: str ) -> None: """Test static directory without read permission receives forbidden response.""" my_dir = tmp_path / "my_dir" my_dir.mkdir() my_dir.chmod(0o000) app = web.Application() app.router.add_static("/", str(tmp_path), show_index=True) client = await aiohttp_client(app) async with client.get(f"/{my_dir.name}/{file_request}") as r: assert r.status == 403 @pytest.mark.parametrize("file_request", ["", "my_file.txt"]) async def test_static_directory_with_mock_permission_error( monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, aiohttp_client: AiohttpClient, file_request: str, ) -> None: """Test static directory with mock permission errors receives forbidden response.""" my_dir = tmp_path / "my_dir" 
my_dir.mkdir() real_iterdir = pathlib.Path.iterdir real_is_dir = pathlib.Path.is_dir def mock_iterdir(self: pathlib.Path) -> Generator[pathlib.Path, None, None]: if my_dir.samefile(self): raise PermissionError() return real_iterdir(self) def mock_is_dir(self: pathlib.Path, **kwargs: Any) -> bool: if my_dir.samefile(self.parent): raise PermissionError() return real_is_dir(self, **kwargs) monkeypatch.setattr("pathlib.Path.iterdir", mock_iterdir) monkeypatch.setattr("pathlib.Path.is_dir", mock_is_dir) app = web.Application() app.router.add_static("/", str(tmp_path), show_index=True) client = await aiohttp_client(app) async with client.get("/") as r: assert r.status == 200 async with client.get(f"/{my_dir.name}/{file_request}") as r: assert r.status == 403 @pytest.mark.skipif( sys.platform.startswith("win32"), reason="Cannot remove read access on Windows" ) async def test_static_file_without_read_permission( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient ) -> None: """Test static file without read permission receives forbidden response.""" my_file = tmp_path / "my_file.txt" my_file.write_text("secret") my_file.chmod(0o000) app = web.Application() app.router.add_static("/", str(tmp_path)) client = await aiohttp_client(app) async with client.get(f"/{my_file.name}") as r: assert r.status == 403 async def test_static_file_with_mock_permission_error( monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, aiohttp_client: AiohttpClient, ) -> None: """Test static file with mock permission errors receives forbidden response.""" my_file = tmp_path / "my_file.txt" my_file.write_text("secret") my_readable = tmp_path / "my_readable.txt" my_readable.write_text("info") real_open = pathlib.Path.open def mock_open(self: pathlib.Path, *args: Any, **kwargs: Any) -> Any: if my_file.samefile(self): raise PermissionError() return real_open(self, *args, **kwargs) monkeypatch.setattr("pathlib.Path.open", mock_open) app = web.Application() app.router.add_static("/", str(tmp_path)) 
client = await aiohttp_client(app) # Test the mock only applies to my_file, then test the permission error. async with client.get(f"/{my_readable.name}") as r: assert r.status == 200 async with client.get(f"/{my_file.name}") as r: assert r.status == 403 async def test_access_symlink_loop( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient ) -> None: # Tests the access to a looped symlink, which could not be resolved. my_dir_path = tmp_path / "my_symlink" pathlib.Path(str(my_dir_path)).symlink_to(str(my_dir_path), True) app = web.Application() # Register global static route: app.router.add_static("/", str(tmp_path), show_index=True) client = await aiohttp_client(app) # Request the root of the static directory. async with client.get("/" + my_dir_path.name) as r: assert r.status == 404 async def test_access_compressed_file_as_symlink( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient ) -> None: """Test that compressed file variants as symlinks are ignored.""" private_file = tmp_path / "private.txt" private_file.write_text("private info") www_dir = tmp_path / "www" www_dir.mkdir() gz_link = www_dir / "file.txt.gz" gz_link.symlink_to(f"../{private_file.name}") app = web.Application() app.router.add_static("/", www_dir) client = await aiohttp_client(app) # Symlink should be ignored; response reflects missing uncompressed file. async with client.get(f"/{gz_link.stem}", auto_decompress=False) as resp: assert resp.status == 404 # Again symlin is ignored, and then uncompressed is served. 
txt_file = gz_link.with_suffix("") txt_file.write_text("public data") resp = await client.get(f"/{txt_file.name}") assert resp.status == 200 assert resp.headers.get("Content-Encoding") is None assert resp.content_type == "text/plain" assert await resp.text() == "public data" resp.release() await client.close() async def test_access_special_resource( unix_sockname: str, aiohttp_client: AiohttpClient ) -> None: """Test access to non-regular files is forbidden using a UNIX domain socket.""" if not getattr(socket, "AF_UNIX", None): # pragma: no cover pytest.skip("UNIX domain sockets not supported") my_special = pathlib.Path(unix_sockname) tmp_path = my_special.parent my_socket = socket.socket(socket.AF_UNIX) my_socket.bind(str(my_special)) assert my_special.is_socket() app = web.Application() app.router.add_static("/", str(tmp_path)) client = await aiohttp_client(app) async with client.get(f"/{my_special.name}") as r: assert r.status == 403 my_socket.close() async def test_access_mock_special_resource( monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, aiohttp_client: AiohttpClient, ) -> None: """Test access to non-regular files is forbidden using a mock FIFO.""" my_special = tmp_path / "my_special" my_special.touch() real_result = my_special.stat() real_stat = os.stat def mock_stat(path: Any, **kwargs: Any) -> os.stat_result: s = real_stat(path, **kwargs) if os.path.samestat(s, real_result): mock_mode = S_IFIFO | S_IMODE(s.st_mode) s = os.stat_result([mock_mode] + list(s)[1:]) return s monkeypatch.setattr("pathlib.Path.stat", mock_stat) monkeypatch.setattr("os.stat", mock_stat) app = web.Application() app.router.add_static("/", str(tmp_path)) client = await aiohttp_client(app) async with client.get(f"/{my_special.name}") as r: assert r.status == 403 async def test_partially_applied_handler(aiohttp_client: AiohttpClient) -> None: app = web.Application() async def handler(data: bytes, request: web.Request) -> web.Response: return web.Response(body=data) 
app.router.add_route("GET", "/", functools.partial(handler, b"hello")) client = await aiohttp_client(app) r = await client.get("/") data = await r.read() assert data == b"hello" async def test_static_head( tmp_path: pathlib.Path, aiohttp_client: AiohttpClient ) -> None: # Test HEAD on static route my_file_path = tmp_path / "test.txt" with my_file_path.open("wb") as fw: fw.write(b"should_not_see_this\n") app = web.Application() app.router.add_static("/", str(tmp_path)) client = await aiohttp_client(app) async with client.head("/test.txt") as r: assert r.status == 200 # Check that there is no content sent (see #4809). This can't easily be # done with aiohttp_client because the buffering can consume the content. reader, writer = await asyncio.open_connection(client.host, client.port) writer.write(b"HEAD /test.txt HTTP/1.1\r\n") writer.write(b"Host: localhost\r\n") writer.write(b"Connection: close\r\n") writer.write(b"\r\n") while await reader.readline() != b"\r\n": pass content = await reader.read() writer.close() assert content == b"" def test_system_route() -> None: route = SystemRoute(web.HTTPCreated(reason="test")) with pytest.raises(RuntimeError): route.url_for() assert route.name is None assert route.resource is None assert "<SystemRoute 201: test>" == repr(route) assert 201 == route.status assert "test" == route.reason async def test_allow_head(aiohttp_client: AiohttpClient) -> None: # Test allow_head on routes. 
app = web.Application() async def handler(request: web.Request) -> web.Response: return web.Response() app.router.add_get("/a", handler, name="a") app.router.add_get("/b", handler, allow_head=False, name="b") client = await aiohttp_client(app) async with client.get("/a") as r: assert r.status == 200 async with client.head("/a") as r: assert r.status == 200 async with client.get("/b") as r: assert r.status == 200 async with client.head("/b") as r: assert r.status == 405 @pytest.mark.parametrize( "path", ( "/a", "/{a}", "/{a:.*}", ), ) def test_reuse_last_added_resource(path: str) -> None: # Test that adding a route with the same name and path of the last added # resource doesn't create a new resource. app = web.Application() async def handler(request: web.Request) -> web.Response: assert False app.router.add_get(path, handler, name="a") app.router.add_post(path, handler, name="a") assert len(app.router.resources()) == 1 def test_resource_raw_match() -> None: app = web.Application() async def handler(request: web.Request) -> web.Response: assert False route = app.router.add_get("/a", handler, name="a") assert route.resource is not None assert route.resource.raw_match("/a") route = app.router.add_get("/{b}", handler, name="b") assert route.resource is not None assert route.resource.raw_match("/{b}") resource = app.router.add_static("/static", ".") assert not resource.raw_match("/static") async def test_add_view(aiohttp_client: AiohttpClient) -> None: app = web.Application() class MyView(web.View): async def get(self) -> web.Response: return web.Response() async def post(self) -> web.Response: return web.Response() app.router.add_view("/a", MyView) client = await aiohttp_client(app) async with client.get("/a") as r: assert r.status == 200 async with client.post("/a") as r: assert r.status == 200 async with client.put("/a") as r: assert r.status == 405 async def test_decorate_view(aiohttp_client: AiohttpClient) -> None: routes = web.RouteTableDef() @routes.view("/a") 
async def test_for_issue_5250(
    aiohttp_client: AiohttpClient, tmp_path: pathlib.Path
) -> None:
    """Regression test for #5250.

    A static resource is mounted at ``/foo`` and a plain GET route at
    ``/foobar``; the request for ``/foobar`` must reach the GET handler.
    """
    app = web.Application()
    app.router.add_static("/foo", tmp_path)

    async def get_foobar(request: web.Request) -> web.Response:
        return web.Response(body="success!")

    app.router.add_get("/foobar", get_foobar)

    client = await aiohttp_client(app)
    # Use the request context manager directly, as the other tests here do.
    async with client.get("/foobar") as resp:
        assert resp.status == 200
        assert (await resp.text()) == "success!"
async def test_order_is_preserved(aiohttp_client: AiohttpClient) -> None:
    """Check which resource answers when several registered routes could match.

    The handler replies with the canonical path of the matched resource.
    Fixed/static paths are always preferred over a regex path.
    """

    async def handler(request: web.Request) -> web.Response:
        matched = request.match_info._route.resource
        assert isinstance(matched, Resource)
        return web.Response(text=matched.canonical)

    app = web.Application()
    # Registration order matters for this test; keep the sequence as is.
    for route_path in (
        "/first/x/{b}/",
        r"/first/{x:.*/b}",
        r"/second/{user}/info",
        "/second/bob/info",
        "/third/bob/info",
        r"/third/{user}/info",
        r"/forth/{name:\d+}",
        "/forth/42",
        "/fifth/42",
        r"/fifth/{name:\d+}",
    ):
        app.router.add_get(route_path, handler)

    client = await aiohttp_client(app)

    # (requested URL, canonical path of the resource expected to answer)
    expectations = (
        ("/first/x/b/", "/first/x/{b}/"),
        ("/second/frank/info", "/second/{user}/info"),
        # Fixed/static paths are always preferred over regex paths
        ("/second/bob/info", "/second/bob/info"),
        ("/third/bob/info", "/third/bob/info"),
        ("/third/frank/info", "/third/{user}/info"),
        ("/forth/21", "/forth/{name}"),
        # Fixed/static paths are always preferred over regex paths
        ("/forth/42", "/forth/42"),
        ("/fifth/21", "/fifth/{name}"),
        ("/fifth/42", "/fifth/42"),
    )
    for url, canonical in expectations:
        r = await client.get(url)
        assert r.status == 200
        assert await r.text() == canonical
async def test_route_with_regex(aiohttp_client: AiohttpClient) -> None:
    """Test a route whose regex placeholder directly follows a fixed string."""

    async def handler(request: web.Request) -> web.Response:
        matched = request.match_info._route.resource
        assert isinstance(matched, Resource)
        return web.Response(text=matched.canonical)

    app = web.Application()
    app.router.add_get("/core/locations{tail:.*}", handler)
    client = await aiohttp_client(app)

    # Every tail variant must resolve to the same resource.
    for url in (
        "/core/locations/tail/here",
        "/core/locations_tail_here",
        "/core/locations_tail;id=abcdef",
    ):
        r = await client.get(url)
        assert r.status == 200
        assert await r.text() == "/core/locations{tail}"
@pytest.fixture
def make_request(
    app: web.Application, protocol: web.RequestHandler[web.Request]
) -> _RequestMaker:
    """Return a factory that builds mocked requests for websocket tests.

    The factory accepts ``method`` and ``path`` plus optional ``headers``;
    when ``headers`` is omitted a websocket upgrade header set is used.
    With ``protocols=True`` a ``SEC-WEBSOCKET-PROTOCOL`` header offering
    ``chat, superchat`` is added.
    """

    def maker(
        method: str,
        path: str,
        headers: CIMultiDict[str] | None = None,
        protocols: bool = False,
    ) -> web.Request:
        if headers is None:
            headers = CIMultiDict(
                {
                    "HOST": "server.example.com",
                    "UPGRADE": "websocket",
                    "CONNECTION": "Upgrade",
                    "SEC-WEBSOCKET-KEY": "dGhlIHNhbXBsZSBub25jZQ==",
                    "ORIGIN": "http://example.com",
                    "SEC-WEBSOCKET-VERSION": "13",
                }
            )
        if protocols:
            # Work on a copy so a caller-supplied mapping is never mutated.
            headers = CIMultiDict(headers)
            headers["SEC-WEBSOCKET-PROTOCOL"] = "chat, superchat"

        return make_mocked_request(method, path, headers, app=app, protocol=protocol)

    return maker
async def test_cancel_heartbeat_cancels_pending_heartbeat_reset_handle(
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Check that _cancel_heartbeat() cancels and clears a pending reset handle."""
    ws = web.WebSocketResponse(heartbeat=0.05)
    ws._loop = loop

    # Receiving data must leave a heartbeat-reset handle behind.
    ws._on_data_received()
    pending_reset = ws._heartbeat_reset_handle
    assert pending_reset is not None

    ws._cancel_heartbeat()

    assert ws._heartbeat_reset_handle is None
    assert ws._need_heartbeat_reset is False
    assert pending_reset.cancelled()
async def test_heartbeat_timeout(make_request: _RequestMaker) -> None:
    """Verify the transport is closed when the heartbeat timeout is reached."""
    transport_closed = asyncio.get_running_loop().create_future()

    def on_transport_close() -> None:
        # Signal the test body that the mocked transport was closed.
        transport_closed.set_result(None)

    req = make_request("GET", "/")
    assert req.transport is not None
    req.transport.close.side_effect = on_transport_close  # type: ignore[attr-defined]

    # Use the monotonic clock resolution as the heartbeat, the close timeout
    # and the protocol's ceil threshold.
    tick = time.get_clock_info("monotonic").resolution
    req._protocol._timeout_ceil_threshold = tick

    ws = web.WebSocketResponse(heartbeat=tick, timeout=tick)
    await ws.prepare(req)
    await transport_closed
    assert ws.closed
async def test_receive_does_not_reset_heartbeat() -> None:
    """Check that receive() returns the reader's message without resetting the heartbeat."""
    ws = web.WebSocketResponse(heartbeat=0.05)
    text_msg = mock.Mock(type=WSMsgType.TEXT)
    # Stub reader whose read() immediately yields the text message.
    ws._reader = mock.Mock(read=mock.AsyncMock(return_value=text_msg))

    with mock.patch.object(ws, "_reset_heartbeat") as reset_heartbeat:
        received = await ws.receive()

    assert received is text_msg
    reset_heartbeat.assert_not_called()
async def test_recv_str_closed(make_request: _RequestMaker) -> None:
    """Check that receive_str() on a closed websocket raises WSMessageTypeError."""
    ws = web.WebSocketResponse()
    await ws.prepare(make_request("GET", "/"))

    # Feed the closed marker so close() can complete.
    assert ws._reader is not None
    ws._reader.feed_data(WS_CLOSED_MESSAGE)
    await ws.close()

    expected_error = (
        f"Received message {WSMsgType.CLOSED}:.+ is not WSMsgType.TEXT"
    )
    with pytest.raises(WSMessageTypeError, match=expected_error):
        await ws.receive_str()
async def test_pong_closed(make_request: _RequestMaker) -> None:
    """Check that pong() on a closed websocket raises ConnectionError."""
    req = make_request("GET", "/")
    ws = web.WebSocketResponse()
    await ws.prepare(req)
    # Feed the closed marker so close() can complete.
    assert ws._reader is not None
    ws._reader.feed_data(WS_CLOSED_MESSAGE)
    await ws.close()
    with pytest.raises(ConnectionError):
        await ws.pong()
async def test_write_eof_idempotent(make_request: _RequestMaker) -> None:
    """Check that repeated write_eof() calls after close() close the transport once."""
    req = make_request("GET", "/")
    ws = web.WebSocketResponse()
    await ws.prepare(req)

    transport = req.transport
    assert transport is not None
    assert len(transport.close.mock_calls) == 0  # type: ignore[attr-defined]

    assert ws._reader is not None
    ws._reader.feed_data(WS_CLOSED_MESSAGE)
    await ws.close()

    for _ in range(3):
        await ws.write_eof()

    assert len(transport.close.mock_calls) == 1  # type: ignore[attr-defined]
async def test_receive_closing(
    make_request: _RequestMaker, loop: asyncio.AbstractEventLoop
) -> None:
    """Check that receive() keeps yielding CLOSING and leaves the websocket open."""
    req = make_request("GET", "/")
    ws = web.WebSocketResponse()
    await ws.prepare(req)

    # Reader stub that always hands back the CLOSING sentinel message.
    ws._reader = mock.Mock()
    ws._reader.read = mock.AsyncMock(return_value=WS_CLOSING_MESSAGE)

    # drain() on the mocked payload writer returns an already-resolved future.
    drained = loop.create_future()
    assert ws._payload_writer is not None
    ws._payload_writer.drain.return_value = drained  # type: ignore[attr-defined]
    drained.set_result(True)

    for _ in range(2):
        msg = await ws.receive()
        assert msg.type == WSMsgType.CLOSING
        assert not ws.closed

    # CLOSING is still reported after the connection is cancelled.
    ws._cancel(ConnectionResetError("Connection lost"))
    msg = await ws.receive()
    assert msg.type == WSMsgType.CLOSING
async def test_multiple_receive_on_close_connection(
    make_request: _RequestMaker,
) -> None:
    """Check that the fifth receive() on a closed websocket raises RuntimeError."""
    ws = web.WebSocketResponse()
    await ws.prepare(make_request("GET", "/"))

    assert ws._reader is not None
    ws._reader.feed_data(WS_CLOSED_MESSAGE)
    await ws.close()

    # The first four receives after close are accepted...
    for _ in range(4):
        await ws.receive()

    # ...and the next one is rejected.
    with pytest.raises(RuntimeError):
        await ws.receive()
async def test_send_with_per_message_deflate(make_request: _RequestMaker) -> None:
    """Check that every send_* helper forwards ``compress`` to the writer's send_frame."""
    req = make_request("GET", "/")
    ws = web.WebSocketResponse()
    await ws.prepare(req)
    with mock.patch.object(ws._writer, "send_frame", autospec=True, spec_set=True) as m:
        await ws.send_str("string", compress=15)
        m.assert_called_with(b"string", WSMsgType.TEXT, compress=15)

        await ws.send_bytes(b"bytes", compress=0)
        m.assert_called_with(b"bytes", WSMsgType.BINARY, compress=0)

        # The str argument is JSON-encoded, hence the extra quotes on the wire.
        await ws.send_json("[{}]", compress=9)
        m.assert_called_with(b'"[{}]"', WSMsgType.TEXT, compress=9)

        await ws.send_frame(b"[{}]", WSMsgType.TEXT, compress=9)
        m.assert_called_with(b"[{}]", WSMsgType.TEXT, compress=9)
async def test_websocket_json(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Send a JSON text frame and check the server echoes back its ``test`` value."""

    async def handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        assert ws.can_prepare(request)

        await ws.prepare(request)
        msg = await ws.receive()
        assert msg.type is WSMsgType.TEXT

        msg_json = msg.json()
        answer = msg_json["test"]
        await ws.send_str(answer)

        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    expected_value = "value"
    # Serialize with json.dumps so the payload is valid JSON by construction.
    payload = json.dumps({"test": expected_value})
    await ws.send_str(payload)

    resp = await ws.receive()
    assert resp.data == expected_value
    await ws.receive()  # Handle close
async def test_websocket_receive_json(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Check that the server decodes a frame with receive_json() and replies with its ``test`` value."""

    async def handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        data = await ws.receive_json()
        answer = data["test"]
        await ws.send_str(answer)

        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    expected_value = "value"
    # Serialize with json.dumps so the payload is valid JSON by construction.
    payload = json.dumps({"test": expected_value})
    await ws.send_str(payload)

    resp = await ws.receive()
    assert resp.data == expected_value
    await ws.receive()  # Handle close
-> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() await ws.send_str(msg + "/answer") await ws.close() closed.set_result(1) return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) ws = await client.ws_connect("/") await ws.send_str("ask") msg = await ws.receive() assert msg.type == aiohttp.WSMsgType.TEXT assert "ask/answer" == msg.data msg = await ws.receive() assert msg.type == aiohttp.WSMsgType.CLOSE assert msg.data == WSCloseCode.OK assert msg.extra == "" assert ws.closed assert ws.close_code == WSCloseCode.OK await closed async def test_send_recv_bytes( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_bytes() await ws.send_bytes(msg + b"/answer") await ws.close() closed.set_result(1) return ws app = web.Application() app.router.add_route("GET", "/", handler) client = await aiohttp_client(app) ws = await client.ws_connect("/") await ws.send_bytes(b"ask") msg = await ws.receive() assert msg.type == aiohttp.WSMsgType.BINARY assert b"ask/answer" == msg.data msg = await ws.receive() assert msg.type == aiohttp.WSMsgType.CLOSE assert msg.data == WSCloseCode.OK assert msg.extra == "" assert ws.closed assert ws.close_code == WSCloseCode.OK await closed async def test_send_recv_json( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) data = await ws.receive_json() await ws.send_json({"response": data["request"]}) await ws.close() closed.set_result(1) return ws app = web.Application() app.router.add_route("GET", "/", 
async def test_close_timeout(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Check that the server's close() returns once its close timeout expires.

    The handler closes with ``timeout=0.1`` and records how long ``close()``
    took; the test asserts an abnormal closure caused by a timeout and that
    the measured duration stays below 0.25 seconds.
    """
    handler_done = loop.create_future()
    close_duration = 1e10  # something big, overwritten by the handler

    async def handler(request: web.Request) -> web.WebSocketResponse:
        nonlocal close_duration
        server_ws = web.WebSocketResponse(timeout=0.1)
        await server_ws.prepare(request)

        received = await server_ws.receive_str()
        assert "request" == received
        await server_ws.send_str("reply")

        assert server_ws._loop is not None
        started = server_ws._loop.time()
        assert await server_ws.close()
        close_duration = server_ws._loop.time() - started

        assert server_ws.close_code == WSCloseCode.ABNORMAL_CLOSURE
        assert isinstance(server_ws.exception(), asyncio.TimeoutError)
        handler_done.set_result(1)
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    await ws.send_str("request")
    reply = await ws.receive_str()
    assert "reply" == reply

    # The server starts closing here. The CLOSE frame is read straight from
    # the private reader rather than through ws.receive(), with pauses shorter
    # than the server-side close timeout (presumably so that the client does
    # not answer the close handshake and the server must rely on its timeout).
    await asyncio.sleep(0.08)
    raw_msg = await ws._reader.read()
    assert raw_msg.type == WSMsgType.CLOSE

    await asyncio.sleep(0.08)
    assert await handler_done

    assert (
        close_duration < 0.25
    ), "close() should have returned before at most 2x timeout."

    await ws.close()
async def test_concurrent_close(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Close from outside the handler while the handler is inside receive()."""
    srv_ws = None

    async def handler(request: web.Request) -> web.WebSocketResponse:
        nonlocal srv_ws
        server_ws = srv_ws = web.WebSocketResponse(
            autoclose=False, protocols=("foo", "bar")
        )
        await server_ws.prepare(request)

        first = await server_ws.receive()
        assert first.type == WSMsgType.CLOSING

        await asyncio.sleep(0)

        second = await server_ws.receive()
        assert second.type == WSMsgType.CLOSED

        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar"))

    # The handler has published its response object by the time we connect.
    assert srv_ws is not None
    await srv_ws.close(code=WSCloseCode.INVALID_TEXT)

    seen = await ws.receive()
    assert seen.type == WSMsgType.CLOSE

    await asyncio.sleep(0)
    seen = await ws.receive()
    assert seen.type == WSMsgType.CLOSED


async def test_concurrent_close_multiple_tasks(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Same as the concurrent-close test, but close() is issued by two tasks."""
    srv_ws = None

    async def handler(request: web.Request) -> web.WebSocketResponse:
        nonlocal srv_ws
        server_ws = srv_ws = web.WebSocketResponse(
            autoclose=False, protocols=("foo", "bar")
        )
        await server_ws.prepare(request)

        first = await server_ws.receive()
        assert first.type == WSMsgType.CLOSING

        await asyncio.sleep(0)

        second = await server_ws.receive()
        assert second.type == WSMsgType.CLOSED

        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar"))

    assert srv_ws is not None
    # Two competing close() calls on the same server-side response.
    first_close = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT))
    second_close = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT))

    seen = await ws.receive()
    assert seen.type == WSMsgType.CLOSE

    await first_close
    await second_close

    await asyncio.sleep(0)
    seen = await ws.receive()
    assert seen.type == WSMsgType.CLOSED


async def test_close_op_code_from_client(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Send a bare CLOSE frame from the client through the private writer."""
    srv_ws: web.WebSocketResponse | None = None

    async def handler(request: web.Request) -> web.WebSocketResponse:
        nonlocal srv_ws
        server_ws = srv_ws = web.WebSocketResponse(protocols=("foo", "bar"))
        await server_ws.prepare(request)

        incoming = await server_ws.receive()
        assert incoming.type == WSMsgType.CLOSE
        await asyncio.sleep(0)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", protocols=("eggs", "bar"))

    # An empty-payload CLOSE frame, written without going through ws.close().
    await ws._writer.send_frame(b"", WSMsgType.CLOSE)

    seen = await ws.receive()
    assert seen.type == WSMsgType.CLOSE

    await asyncio.sleep(0)
    seen = await ws.receive()
    assert seen.type == WSMsgType.CLOSED
async def test_auto_pong_with_closing_by_peer(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Client pings, gets a PONG back, then closes with a code and message."""
    handler_done = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)

        # First receive() result is not inspected; the second must be the close.
        await server_ws.receive()

        closing = await server_ws.receive()
        assert closing.type == WSMsgType.CLOSE
        assert closing.data == WSCloseCode.OK
        assert closing.extra == "exit message"
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoclose=False, autoping=False)
    await ws.ping()
    await ws.send_str("ask")

    pong = await ws.receive()
    assert pong.type == WSMsgType.PONG

    await ws.close(code=WSCloseCode.OK, message=b"exit message")
    await handler_done


async def test_ping(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Server sends a PING with a payload; the client sees it and pongs."""
    handler_done = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)

        await server_ws.ping(b"data")
        await server_ws.receive()
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    # autoping is off so the PING is delivered to the test instead of handled.
    ws = await client.ws_connect("/", autoping=False)

    ping = await ws.receive()
    assert ping.type == WSMsgType.PING
    assert ping.data == b"data"

    await ws.pong()
    await ws.close()
    await handler_done
async def test_client_ping(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Client sends a PING with a payload and receives the matching PONG."""
    closed = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        await ws.receive()
        closed.set_result(None)
        return ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoping=False)

    await ws.ping(b"data")
    msg = await ws.receive()
    assert msg.type == WSMsgType.PONG
    assert msg.data == b"data"
    await ws.pong()
    await ws.close()
    # Wait for the handler to finish, as the sibling ping/pong tests do;
    # previously the future was resolved by the handler but never awaited.
    await closed


async def test_pong(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Server with autoping disabled answers a PING by calling pong() itself."""
    closed = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse(autoping=False)
        await ws.prepare(request)

        # With autoping off the PING is handed to the handler.
        msg = await ws.receive()
        assert msg.type == WSMsgType.PING
        await ws.pong(b"data")

        msg = await ws.receive()
        assert msg.type == WSMsgType.CLOSE
        assert msg.data == WSCloseCode.OK
        assert msg.extra == "exit message"
        closed.set_result(None)
        return ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoping=False)

    await ws.ping(b"data")
    msg = await ws.receive()
    assert msg.type == WSMsgType.PONG
    assert msg.data == b"data"

    await ws.close(code=WSCloseCode.OK, message=b"exit message")

    await closed


async def test_change_status(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """A status set before prepare() is replaced by 101 once prepared."""
    closed = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        ws.set_status(200)
        assert 200 == ws.status
        await ws.prepare(request)
        assert 101 == ws.status
        await ws.close()
        closed.set_result(None)
        return ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoping=False)

    await ws.close()
    await closed
    # Second close() on the already-closed client socket, kept from the original.
    await ws.close()
async def test_handle_protocol(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """The subprotocol offered by both sides ("bar") is the one selected."""
    handler_done = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse(protocols=("foo", "bar"))
        await server_ws.prepare(request)
        await server_ws.close()
        assert "bar" == server_ws.ws_protocol
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", protocols=("eggs", "bar"))

    await ws.close()
    await handler_done


async def test_server_close_handshake(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Server closes first; the client (autoclose off) replies explicitly."""
    handler_done = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse(protocols=("foo", "bar"))
        await server_ws.prepare(request)
        await server_ws.close()
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar"))

    closing = await ws.receive()
    assert closing.type == WSMsgType.CLOSE

    await ws.close()
    await handler_done


async def test_client_close_handshake(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Client closes first; the server (autoclose off) completes the handshake."""
    handler_done = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar"))
        await server_ws.prepare(request)

        closing = await server_ws.receive()
        assert closing.type == WSMsgType.CLOSE
        # autoclose is off, so receiving CLOSE must not close the socket yet.
        assert not server_ws.closed

        await server_ws.close()
        assert server_ws.closed
        assert server_ws.close_code == WSCloseCode.INVALID_TEXT  # type: ignore[unreachable]

        after_close = await server_ws.receive()
        assert after_close.type == WSMsgType.CLOSED

        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar"))

    await ws.close(code=WSCloseCode.INVALID_TEXT)

    seen = await ws.receive()
    assert seen.type == WSMsgType.CLOSED
    await handler_done
async def test_server_close_handshake_server_eats_client_messages(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Client keeps sending frames after the server's CLOSE, then closes."""
    handler_done = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse(protocols=("foo", "bar"))
        await server_ws.prepare(request)
        await server_ws.close()
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect(
        "/", autoclose=False, autoping=False, protocols=("eggs", "bar")
    )

    closing = await ws.receive()
    assert closing.type == WSMsgType.CLOSE

    # Frames sent while the server is waiting for the close reply.
    await ws.send_str("text")
    await ws.send_bytes(b"bytes")
    await ws.ping()

    await ws.close()
    await handler_done


async def test_receive_timeout(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """receive() raises TimeoutError when ``receive_timeout`` is set on the response."""
    raised = False

    async def handler(request: web.Request) -> web.WebSocketResponse:
        nonlocal raised
        server_ws = web.WebSocketResponse(receive_timeout=0.1)
        await server_ws.prepare(request)

        try:
            await server_ws.receive()
        except asyncio.TimeoutError:
            raised = True

        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    await ws.receive()
    await ws.close()
    assert raised


async def test_custom_receive_timeout(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """receive() raises TimeoutError when the timeout is passed per call."""
    raised = False

    async def handler(request: web.Request) -> web.WebSocketResponse:
        nonlocal raised
        server_ws = web.WebSocketResponse(receive_timeout=None)
        await server_ws.prepare(request)

        try:
            await server_ws.receive(0.1)
        except asyncio.TimeoutError:
            raised = True

        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    await ws.receive()
    await ws.close()
    assert raised


async def test_heartbeat(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """A server created with ``heartbeat=0.05`` delivers a PING to the client."""

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse(heartbeat=0.05)
        await server_ws.prepare(request)
        await server_ws.receive()
        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)

    client = await aiohttp_client(app)
    # autoping is off so the heartbeat PING reaches the test.
    ws = await client.ws_connect("/", autoping=False)
    heartbeat_msg = await ws.receive()

    assert heartbeat_msg.type == aiohttp.WSMsgType.PING

    await ws.close()
async def test_heartbeat_no_pong(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Client sees the heartbeat PING and closes without ever sending a PONG."""

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse(heartbeat=0.05)
        await server_ws.prepare(request)

        await server_ws.receive()
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)

    client = await aiohttp_client(app)
    ws = await client.ws_connect("/", autoping=False)
    heartbeat_msg = await ws.receive()
    assert heartbeat_msg.type == aiohttp.WSMsgType.PING
    await ws.close()


async def test_heartbeat_connection_closed(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Test that the connection is closed while ping is in progress."""
    ping_count = 0

    async def handler(request: web.Request) -> NoReturn:
        nonlocal ping_count
        ws_server = web.WebSocketResponse(heartbeat=0.05)
        await ws_server.prepare(request)
        # The transport's write is patched to raise ConnectionResetError in
        # order to simulate a connection reset. Closing the connection
        # normally would make the server cancel the heartbeat task, and no
        # ping would be attempted.
        assert ws_server._req is not None
        assert ws_server._writer is not None
        with (
            mock.patch.object(
                ws_server._req.transport, "write", side_effect=ConnectionResetError
            ),
            mock.patch.object(
                ws_server._writer, "send_frame", wraps=ws_server._writer.send_frame
            ) as frame_spy,
        ):
            try:
                await ws_server.receive()
            finally:
                # Count heartbeat pings attempted, whatever receive() did.
                expected_call = mock.call(b"", WSMsgType.PING)
                ping_count = frame_spy.call_args_list.count(expected_call)
        assert False

    app = web.Application()
    app.router.add_get("/", handler)

    client = await aiohttp_client(app)
    ws = await client.ws_connect("/", autoping=False)

    final_msg = await ws.receive()
    assert final_msg.type is aiohttp.WSMsgType.CLOSED
    assert final_msg.extra is None
    assert ws.close_code == WSCloseCode.ABNORMAL_CLOSURE
    assert ping_count == 1
    await ws.close()
async def test_heartbeat_failure_ends_receive(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Test that no heartbeat response to the server ends the receive call."""
    ws_server_close_code = None
    ws_server_exception = None

    async def handler(request: web.Request) -> NoReturn:
        nonlocal ws_server_close_code, ws_server_exception
        ws_server = web.WebSocketResponse(heartbeat=0.05)
        await ws_server.prepare(request)

        try:
            await ws_server.receive()
        finally:
            # Capture the server-side outcome for the assertions below.
            ws_server_exception = ws_server.exception()
            ws_server_close_code = ws_server.close_code
        assert False

    app = web.Application()
    app.router.add_get("/", handler)

    client = await aiohttp_client(app)
    ws = await client.ws_connect("/", autoping=False)

    # The PING is received but deliberately left unanswered.
    first = await ws.receive()
    assert first.type is aiohttp.WSMsgType.PING

    second = await ws.receive()
    assert second.type is aiohttp.WSMsgType.CLOSED

    assert ws.close_code == WSCloseCode.ABNORMAL_CLOSURE
    assert ws_server_close_code == WSCloseCode.ABNORMAL_CLOSURE
    assert isinstance(ws_server_exception, asyncio.TimeoutError)
    assert str(ws_server_exception) == "No PONG received after 0.025 seconds"
    await ws.close()


async def test_heartbeat_no_pong_send_many_messages(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Test no pong after sending many messages."""

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse(heartbeat=0.05)
        await server_ws.prepare(request)

        for _ in range(10):
            await server_ws.send_str("test")

        await server_ws.receive()
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)

    client = await aiohttp_client(app)
    ws = await client.ws_connect("/", autoping=False)

    # Ten text frames arrive first ...
    for _ in range(10):
        text_msg = await ws.receive()
        assert text_msg.type is aiohttp.WSMsgType.TEXT
        assert text_msg.data == "test"

    # ... and the heartbeat PING follows them.
    heartbeat_msg = await ws.receive()
    assert heartbeat_msg.type is aiohttp.WSMsgType.PING
    await ws.close()
async def test_heartbeat_no_pong_receive_many_messages(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Test no pong after receiving many messages."""

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse(heartbeat=0.05)
        await server_ws.prepare(request)

        for _ in range(10):
            incoming = await server_ws.receive()
            assert incoming.type is aiohttp.WSMsgType.TEXT

        await server_ws.receive()
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)

    client = await aiohttp_client(app)
    ws = await client.ws_connect("/", autoping=False)

    for _ in range(10):
        await ws.send_str("test")

    heartbeat_msg = await ws.receive()
    assert heartbeat_msg.type is aiohttp.WSMsgType.PING
    await ws.close()


async def test_server_ws_async_for(
    loop: asyncio.AbstractEventLoop, aiohttp_server: AiohttpServer
) -> None:
    """Handler consumes messages with ``async for`` and answers each one."""
    handler_done = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)

        async for incoming in server_ws:
            assert incoming.type == aiohttp.WSMsgType.TEXT
            text = incoming.data
            await server_ws.send_str(text + "/answer")

        await server_ws.close()
        handler_done.set_result(1)
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    server = await aiohttp_server(app)

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(server.make_url("/")) as resp:
            for question in ["q1", "q2", "q3"]:
                await resp.send_str(question)
                answer = await resp.receive()
                assert answer.type == aiohttp.WSMsgType.TEXT
                assert question + "/answer" == answer.data

            await resp.close()
            await handler_done


async def test_closed_async_for(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Closing inside the ``async for`` body ends the iteration after one message."""
    handler_done = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)

        messages = []
        async for incoming in server_ws:
            messages.append(incoming)
            assert "stop" == incoming.data
            await server_ws.send_str("stopping")
            await server_ws.close()

        assert 1 == len(messages)
        only = messages[0]
        assert only.type == WSMsgType.TEXT
        assert only.data == "stop"

        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    await ws.send_str("stop")

    reply = await ws.receive()
    assert reply.type == WSMsgType.TEXT
    assert reply.data == "stopping"

    await ws.close()
    await handler_done
async def test_websocket_disable_keepalive(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """prepare() switches off the protocol's keep-alive state.

    The same route is hit twice: once as a plain GET (which cannot be
    upgraded and gets a text response) and once as a WebSocket handshake.
    """

    async def handler(request: web.Request) -> web.StreamResponse:
        server_ws = web.WebSocketResponse()
        if not server_ws.can_prepare(request):
            # Plain HTTP request: answer without upgrading.
            return web.Response(text="OK")

        assert request.protocol._keepalive
        await server_ws.prepare(request)
        assert not request.protocol._keepalive
        assert not request.protocol._keepalive_handle  # type: ignore[unreachable]

        await server_ws.send_str("OK")
        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    resp = await client.get("/")
    body = await resp.text()
    assert body == "OK"

    ws = await client.ws_connect("/")
    greeting = await ws.receive_str()
    assert greeting == "OK"
    await ws.receive()  # Handle close


async def test_receive_str_nonstring(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Client receive_str() raises TypeError when a binary frame arrives."""

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse()
        assert server_ws.can_prepare(request)

        await server_ws.prepare(request)
        await server_ws.send_bytes(b"answer")
        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    with pytest.raises(TypeError):
        await ws.receive_str()
    await ws.receive()  # Handle close


async def test_receive_bytes_nonbytes(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Client receive_bytes() raises TypeError when a text frame arrives."""

    async def handler(request: web.Request) -> NoReturn:
        server_ws = web.WebSocketResponse()
        assert server_ws.can_prepare(request)

        await server_ws.prepare(request)
        await server_ws.send_str("answer")
        assert False

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    with pytest.raises(TypeError):
        await ws.receive_bytes()
async def test_bug3380(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """A rejected WebSocket handshake must not break later plain requests."""

    async def handle_null(request: web.Request) -> web.Response:
        return web.json_response({"err": None})

    async def ws_handler(request: web.Request) -> web.Response:
        # Always refuses the upgrade.
        return web.Response(status=401)

    app = web.Application()
    app.router.add_route("GET", "/ws", ws_handler)
    app.router.add_route("GET", "/api/null", handle_null)

    client = await aiohttp_client(app)

    before = await client.get("/api/null")
    assert (await before.json()) == {"err": None}
    before.close()

    with pytest.raises(WSServerHandshakeError):
        await client.ws_connect("/ws")

    # The same endpoint still answers after the failed handshake.
    after = await client.get("/api/null", timeout=aiohttp.ClientTimeout(total=1))
    assert (await after.json()) == {"err": None}
    after.close()


async def test_receive_being_cancelled_keeps_connection_open(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Cancelling a pending receive() leaves the socket usable afterwards."""
    handler_done = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse(autoping=False)
        await server_ws.prepare(request)

        # Start a receive(), let it begin waiting, then cancel it.
        pending_receive = asyncio.create_task(server_ws.receive())
        await asyncio.sleep(0)
        pending_receive.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await pending_receive

        ping = await server_ws.receive()
        assert ping.type == WSMsgType.PING
        await asyncio.sleep(0)
        await server_ws.pong(b"data")

        closing = await server_ws.receive()
        assert closing.type == WSMsgType.CLOSE
        assert closing.data == WSCloseCode.OK
        assert closing.extra == "exit message"
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoping=False)

    await asyncio.sleep(0)
    await ws.ping(b"data")

    pong = await ws.receive()
    assert pong.type == WSMsgType.PONG
    assert pong.data == b"data"

    await ws.close(code=WSCloseCode.OK, message=b"exit message")

    await handler_done
async def test_receive_timeout_keeps_connection_open(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """A receive() that times out leaves the socket usable afterwards."""
    handler_done = loop.create_future()
    timed_out = loop.create_future()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse(autoping=False)
        await server_ws.prepare(request)

        # Smallest positive float as timeout; a TimeoutError is tolerated.
        expiring_receive = asyncio.create_task(
            server_ws.receive(sys.float_info.min)
        )
        with contextlib.suppress(asyncio.TimeoutError):
            await expiring_receive

        # Tell the client it may start talking now.
        timed_out.set_result(None)

        ping = await server_ws.receive()
        assert ping.type == WSMsgType.PING
        await asyncio.sleep(0)
        await server_ws.pong(b"data")

        closing = await server_ws.receive()
        assert closing.type == WSMsgType.CLOSE
        assert closing.data == WSCloseCode.OK
        assert closing.extra == "exit message"
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoping=False)

    await timed_out
    await ws.ping(b"data")

    pong = await ws.receive()
    assert pong.type == WSMsgType.PONG
    assert pong.data == b"data"

    await ws.close(code=WSCloseCode.OK, message=b"exit message")

    await handler_done


async def test_websocket_shutdown(aiohttp_client: AiohttpClient) -> None:
    """Check that a connected client receives the CLOSE sent from on_shutdown."""
    url = "/ws"
    app = web.Application()

    websockets = web.AppKey("websockets", weakref.WeakSet[web.WebSocketResponse])
    app[websockets] = weakref.WeakSet()

    # Second registry: the sockets the shutdown hook has to close.
    shutdown_websockets = web.AppKey(
        "shutdown_websockets", weakref.WeakSet[web.WebSocketResponse]
    )
    app[shutdown_websockets] = weakref.WeakSet()

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        websocket = web.WebSocketResponse()
        await websocket.prepare(request)
        request.app[websockets].add(websocket)
        request.app[shutdown_websockets].add(websocket)

        try:
            async for message in websocket:
                assert message.type is WSMsgType.TEXT
                echoed = {"ok": True, "message": message.json()}
                await websocket.send_json(echoed)
        finally:
            request.app[websockets].discard(websocket)

        return websocket

    async def on_shutdown(app: web.Application) -> None:
        # Drain the registry, closing each socket with GOING_AWAY.
        while app[shutdown_websockets]:
            doomed = app[shutdown_websockets].pop()
            await doomed.close(
                code=aiohttp.WSCloseCode.GOING_AWAY,
                message=b"Server shutdown",
            )

    app.router.add_get(url, websocket_handler)
    app.on_shutdown.append(on_shutdown)

    client = await aiohttp_client(app)

    websocket = await client.ws_connect(url)

    message = {"message": "hi"}
    await websocket.send_json(message)
    reply = await websocket.receive_json()
    assert reply == {"ok": True, "message": message}

    await app.shutdown()

    # Not closed until the CLOSE frame has actually been received.
    assert websocket.closed is False

    close_msg = await websocket.receive()

    assert close_msg.type is aiohttp.http.WSMsgType.CLOSE
    assert close_msg.data == aiohttp.WSCloseCode.GOING_AWAY
    assert close_msg.extra == "Server shutdown"

    assert websocket.closed is True
async def test_ws_close_return_code(aiohttp_client: AiohttpClient) -> None:
    """Test that the close code is returned when the server closes the connection."""

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)
        await server_ws.receive()
        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    resp = await client.ws_connect("/")
    await resp.send_str("some data")

    closing = await resp.receive()
    assert closing.type is aiohttp.WSMsgType.CLOSE
    assert resp.close_code == WSCloseCode.OK


async def test_abnormal_closure_when_server_does_not_receive(
    aiohttp_client: AiohttpClient,
) -> None:
    """Test abnormal closure when the server closes and a message is pending."""

    async def handler(request: web.Request) -> web.WebSocketResponse:
        # The close timeout is set to 0 because the server would otherwise
        # wait 10 seconds (the default) for a close response. With that wait,
        # the client's autoclose inside resp.receive() would succeed and the
        # connection would be closed cleanly from both sides.
        server_ws = web.WebSocketResponse(timeout=0)
        await server_ws.prepare(request)
        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    resp = await client.ws_connect("/")
    await resp.send_str("some data")
    await asyncio.sleep(0.1)

    closing = await resp.receive()
    assert closing.type is aiohttp.WSMsgType.CLOSE
    assert resp.close_code == WSCloseCode.ABNORMAL_CLOSURE
async def test_abnormal_closure_when_client_does_not_close(
    aiohttp_client: AiohttpClient,
) -> None:
    """Test abnormal closure when the server closes and the client doesn't respond."""
    close_code: WSCloseCode | None = None

    async def handler(request: web.Request) -> web.WebSocketResponse:
        nonlocal close_code
        # Setting a short close timeout
        ws = web.WebSocketResponse(timeout=0.1)
        await ws.prepare(request)
        await ws.close()
        assert ws.close_code is not None
        close_code = WSCloseCode(ws.close_code)
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    # autoclose is off and the client stays silent for longer than the
    # server's 0.1s close timeout.
    async with client.ws_connect("/", autoclose=False):
        await asyncio.sleep(0.2)
    await client.server.close()
    assert close_code == WSCloseCode.ABNORMAL_CLOSURE


async def test_normal_closure_while_client_sends_msg(
    aiohttp_client: AiohttpClient,
) -> None:
    """Test normal closure when the server closes and the client responds properly."""
    close_code: WSCloseCode | None = None
    got_close_code = asyncio.Event()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        nonlocal close_code
        # Setting a longer close timeout to avoid race conditions
        ws = web.WebSocketResponse(timeout=1.0)
        await ws.prepare(request)
        await ws.close()
        assert ws.close_code is not None
        close_code = WSCloseCode(ws.close_code)
        got_close_code.set()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    async with client.ws_connect("/", autoclose=False) as ws:
        # send text and close message during server close timeout
        await asyncio.sleep(0.1)
        await ws.send_str("Hello")
        await ws.close()
        # Wait (at most 0.5s) for the handler to record its close code.
        # wait_for() cancels the inner wait on timeout, so no helper task is
        # left pending, unlike asyncio.wait() over two ad-hoc tasks.
        with contextlib.suppress(asyncio.TimeoutError):
            await asyncio.wait_for(got_close_code.wait(), timeout=0.5)
    await client.server.close()
    assert close_code == WSCloseCode.OK
async def test_websocket_prepare_timeout_close_issue(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Test that WebSocket can handle prepare with early returns.

    Regression test for issue #6009, in which the ``prepared`` property
    checked ``_payload_writer`` where it should have checked ``_writer``.
    """

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse()
        assert server_ws.can_prepare(request)
        await server_ws.prepare(request)
        await server_ws.send_str("test")
        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/ws", handler)
    client = await aiohttp_client(app)

    # Open the WebSocket and read the single text frame.
    ws = await client.ws_connect("/ws")
    received = await ws.receive()
    assert received.type is WSMsgType.TEXT
    assert received.data == "test"
    await ws.close()


async def test_websocket_prepare_timeout_from_issue_reproducer(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Test websocket behavior when prepare is interrupted.

    Covers the fix for issue #6009, in which close() failed after an
    interrupted prepare().
    """
    prepare_complete = asyncio.Event()
    close_complete = asyncio.Event()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse()

        # Handshake first, then signal the test.
        await server_ws.prepare(request)
        prepare_complete.set()

        # One message proves the connection is usable.
        await server_ws.send_str("connected")

        # Block until the client closes.
        closing = await server_ws.receive()
        assert closing.type is WSMsgType.CLOSE
        await server_ws.close()
        close_complete.set()

        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/ws", handler)
    client = await aiohttp_client(app)

    # Connect and confirm the greeting arrives.
    ws = await client.ws_connect("/ws")
    await prepare_complete.wait()

    greeting = await ws.receive()
    assert greeting.type is WSMsgType.TEXT
    assert greeting.data == "connected"

    # Close from the client side and wait for the handler to finish closing.
    await ws.close()
    await close_complete.wait()


async def test_websocket_prepared_property(
    loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
    """Test that WebSocketResponse.prepared property correctly reflects state."""
    prepare_called = asyncio.Event()

    async def handler(request: web.Request) -> web.WebSocketResponse:
        server_ws = web.WebSocketResponse()

        # Before prepare(): must report "not prepared".
        before = server_ws.prepared
        assert not before

        await server_ws.prepare(request)
        prepare_called.set()

        # After prepare(): must report "prepared".
        after = server_ws.prepared
        assert after

        # One message proves the connection is usable.
        await server_ws.send_str("test")
        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    await prepare_called.wait()
    received = await ws.receive()
    assert received.type is WSMsgType.TEXT
    assert received.data == "test"
    await ws.close()


async def test_receive_text_as_bytes_server_side(aiohttp_client: AiohttpClient) -> None:
    """Test server receiving TEXT messages as raw bytes with decode_text=False."""

    async def websocket_handler(
        request: web.Request,
    ) -> web.WebSocketResponse[Literal[False]]:
        server_ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse(
            decode_text=False
        )
        await server_ws.prepare(request)

        # The frame keeps its TEXT type, but the payload is bytes.
        incoming = await server_ws.receive()
        assert incoming.type is aiohttp.WSMsgType.TEXT
        assert isinstance(incoming.data, bytes)
        assert incoming.data == b"test message"

        # Reply as a binary frame.
        await server_ws.send_bytes(incoming.data + b"/reply")
        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", websocket_handler)
    client = await aiohttp_client(app)

    async with client.ws_connect("/") as ws:
        await ws.send_str("test message")
        reply = await ws.receive()
        assert reply.type is aiohttp.WSMsgType.BINARY
        assert reply.data == b"test message/reply"
        await ws.close()
async def test_receive_text_as_bytes_server_iteration(
    aiohttp_client: AiohttpClient,
) -> None:
    """Test server iterating over WebSocket with decode_text=False."""

    async def websocket_handler(
        request: web.Request,
    ) -> web.WebSocketResponse[Literal[False]]:
        server_ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse(
            decode_text=False
        )
        await server_ws.prepare(request)

        async for incoming in server_ws:
            # Only TEXT and BINARY frames are expected; either way the
            # payload must be bytes and is echoed back as a binary frame.
            if incoming.type is not aiohttp.WSMsgType.TEXT:
                assert incoming.type is aiohttp.WSMsgType.BINARY
            assert isinstance(incoming.data, bytes)
            await server_ws.send_bytes(incoming.data)

        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", websocket_handler)
    client = await aiohttp_client(app)

    async with client.ws_connect("/") as ws:
        # A TEXT frame comes back as BINARY with the same bytes.
        await ws.send_str("hello")
        echoed = await ws.receive()
        assert echoed.type is aiohttp.WSMsgType.BINARY
        assert echoed.data == b"hello"

        # A BINARY frame is echoed unchanged.
        await ws.send_bytes(b"world")
        echoed = await ws.receive()
        assert echoed.type is aiohttp.WSMsgType.BINARY
        assert echoed.data == b"world"

        await ws.close()


async def test_server_decode_text_default_true(aiohttp_client: AiohttpClient) -> None:
    """Test that server decode_text defaults to True for backward compatibility."""

    async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
        # decode_text is intentionally not passed; the default is under test.
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)

        incoming = await server_ws.receive()
        assert incoming.type is aiohttp.WSMsgType.TEXT
        assert isinstance(incoming.data, str)
        assert incoming.data == "test"

        await server_ws.send_str(incoming.data + "/reply")
        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", websocket_handler)
    client = await aiohttp_client(app)

    async with client.ws_connect("/") as ws:
        await ws.send_str("test")
        reply = await ws.receive()
        assert reply.type is aiohttp.WSMsgType.TEXT
        assert isinstance(reply.data, str)
        assert reply.data == "test/reply"
        await ws.close()
async def test_server_receive_str_returns_bytes_with_decode_text_false(
    aiohttp_client: AiohttpClient,
) -> None:
    """Server ``receive_str()`` hands back ``bytes`` when ``decode_text=False``."""

    async def websocket_handler(
        request: web.Request,
    ) -> web.WebSocketResponse[Literal[False]]:
        ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse(
            decode_text=False
        )
        await ws.prepare(request)

        # With decoding disabled the text frame arrives undecoded.
        payload = await ws.receive_str()
        assert isinstance(payload, bytes)
        assert payload == b"hello server"

        await ws.send_str("got bytes")
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", websocket_handler)
    client = await aiohttp_client(app)

    async with client.ws_connect("/") as conn:
        await conn.send_str("hello server")
        reply = await conn.receive()
        assert reply.data == "got bytes"
async def test_server_receive_json_with_orjson_style_loads(
    aiohttp_client: AiohttpClient,
) -> None:
    """Server ``receive_json()`` passes bytes to a bytes-accepting ``loads``.

    The websocket is created with ``decode_text=False`` and the custom loader
    asserts that what it is handed is ``bytes``.
    """

    def orjson_style_loads(data: bytes) -> dict[str, str]:
        """Stand-in for ``orjson.loads``: only accepts bytes input."""
        assert isinstance(data, bytes)
        parsed: dict[str, str] = json.loads(data)
        return parsed

    async def websocket_handler(
        request: web.Request,
    ) -> web.WebSocketResponse[Literal[False]]:
        ws: web.WebSocketResponse[Literal[False]] = web.WebSocketResponse(
            decode_text=False
        )
        await ws.prepare(request)

        # The loader above receives the raw frame payload.
        decoded = await ws.receive_json(loads=orjson_style_loads)
        assert decoded == {"test": "value"}

        await ws.send_str("success")
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", websocket_handler)
    client = await aiohttp_client(app)

    conn = await client.ws_connect("/")
    await conn.send_str('{"test": "value"}')
    reply = await conn.receive()
    assert reply.type is aiohttp.WSMsgType.TEXT
    assert reply.data == "success"
    await conn.close()
def gen_ws_headers(
    protocols: str = "",
    compress: int = 0,
    extension_text: str = "",
    server_notakeover: bool = False,
    client_notakeover: bool = False,
) -> tuple[list[tuple[str, str]], str]:
    """Build the request headers for a websocket upgrade handshake.

    Args:
        protocols: Value for ``Sec-Websocket-Protocol``; omitted when empty.
        compress: Window bits for a ``permessage-deflate`` offer. ``0`` sends
            no extension header; values below 15 add an explicit
            ``server_max_window_bits`` parameter.
        extension_text: Extra text appended verbatim as the last
            ``"; "``-separated element of the extension header.
        server_notakeover: Add ``server_no_context_takeover`` to the offer.
        client_notakeover: Add ``client_no_context_takeover`` to the offer.

    Returns:
        ``(headers, key)`` where ``headers`` is a list of name/value pairs and
        ``key`` is the random base64 ``Sec-Websocket-Key`` that was generated.
    """
    key = base64.b64encode(os.urandom(16)).decode()
    headers: list[tuple[str, str]] = [
        ("Upgrade", "websocket"),
        ("Connection", "upgrade"),
        ("Sec-Websocket-Version", "13"),
        ("Sec-Websocket-Key", key),
    ]

    if protocols:
        headers.append(("Sec-Websocket-Protocol", protocols))

    if compress:
        # Collect the offer's parameters, then join them with "; ".
        offer = ["permessage-deflate"]
        if compress < 15:
            offer.append(f"server_max_window_bits={compress}")
        if server_notakeover:
            offer.append("server_no_context_takeover")
        if client_notakeover:
            offer.append("client_no_context_takeover")
        if extension_text:
            offer.append(extension_text)
        headers.append(("Sec-Websocket-Extensions", "; ".join(offer)))

    return headers, key
async def test_protocol_key_bad_size() -> None:
    """A ``Sec-Websocket-Key`` encoding only 2 random bytes is rejected."""
    # gen_ws_headers() encodes 16 random bytes; this key encodes just two.
    short_key = base64.b64encode(os.urandom(2)).decode()
    handshake_headers = {
        "Upgrade": "websocket",
        "Connection": "upgrade",
        "Sec-Websocket-Version": "13",
        "Sec-Websocket-Key": short_key,
    }
    ws = web.WebSocketResponse()
    request = make_mocked_request("GET", "/", headers=handshake_headers)

    with pytest.raises(web.HTTPBadRequest):
        await ws.prepare(request)
async def test_handshake_protocol_agreement() -> None:
    """Check which subprotocol is chosen when several overlap.

    The request offers ``worse_proto,chat`` while the response is configured
    with ``["best", "chat", "worse_proto"]``; the test expects the negotiated
    protocol to be ``worse_proto``, i.e. the first entry of the request's
    list rather than the first matching entry of the response's list.
    """
    offered_in_request = "worse_proto,chat"
    configured_on_response = ["best", "chat", "worse_proto"]
    expected = "worse_proto"

    ws = web.WebSocketResponse(protocols=configured_on_response)
    request_headers, _key = gen_ws_headers(offered_in_request)
    request = make_mocked_request("GET", "/", headers=request_headers)

    await ws.prepare(request)

    assert ws.ws_protocol == expected
("permessage-deflate"), hdrs assert compress == 15 def test_handshake_compress_wbits() -> None: hdrs, sec_key = gen_ws_headers(compress=9) req = make_mocked_request("GET", "/", headers=hdrs) ws = web.WebSocketResponse() headers, _, compress, notakeover = ws._handshake(req) assert "Sec-Websocket-Extensions" in headers assert headers["Sec-Websocket-Extensions"] == ( "permessage-deflate; server_max_window_bits=9" ) assert compress == 9 def test_handshake_compress_wbits_error() -> None: hdrs, sec_key = gen_ws_headers(compress=6) req = make_mocked_request("GET", "/", headers=hdrs) ws = web.WebSocketResponse() headers, _, compress, notakeover = ws._handshake(req) assert "Sec-Websocket-Extensions" not in headers assert compress == 0 def test_handshake_compress_bad_ext() -> None: hdrs, sec_key = gen_ws_headers(compress=15, extension_text="bad") req = make_mocked_request("GET", "/", headers=hdrs) ws = web.WebSocketResponse() headers, _, compress, notakeover = ws._handshake(req) assert "Sec-Websocket-Extensions" not in headers assert compress == 0 def test_handshake_compress_multi_ext_bad() -> None: hdrs, sec_key = gen_ws_headers( compress=15, extension_text="bad, permessage-deflate" ) req = make_mocked_request("GET", "/", headers=hdrs) ws = web.WebSocketResponse() headers, _, compress, notakeover = ws._handshake(req) assert "Sec-Websocket-Extensions" in headers assert headers["Sec-Websocket-Extensions"] == "permessage-deflate" def test_handshake_compress_multi_ext_wbits() -> None: hdrs, sec_key = gen_ws_headers(compress=6, extension_text=", permessage-deflate") req = make_mocked_request("GET", "/", headers=hdrs) ws = web.WebSocketResponse() headers, _, compress, notakeover = ws._handshake(req) assert "Sec-Websocket-Extensions" in headers assert headers["Sec-Websocket-Extensions"] == "permessage-deflate" assert compress == 15 def test_handshake_no_transfer_encoding() -> None: hdrs, sec_key = gen_ws_headers() req = make_mocked_request("GET", "/", headers=hdrs) ws = 
def build_frame(
    message: bytes,
    opcode: int,
    noheader: bool = False,
    is_fin: bool = True,
    ZLibBackend: ZLibBackendWrapper | None = None,
    mask: bool = False,
) -> bytes:
    """Serialize ``message`` as a websocket frame for feeding to the parser.

    Args:
        message: Frame payload.
        opcode: Opcode stored in the low bits of the first header byte.
        noheader: Return only the payload (after optional compression)
            without the header. Not allowed together with ``mask``.
        is_fin: Set the FIN bit (0x80) in the first header byte.
        ZLibBackend: When truthy, the payload is deflated with this backend
            (``wbits=-9``), the 4-byte sync-flush trailer is stripped and the
            0x40 bit is set in the first header byte. Note the parameter
            shadows the module-level import of the same name.
        mask: Set the mask bit in the length byte and mask the payload with a
            random 32-bit key, which is emitted between header and payload.

    Returns:
        The encoded frame (or just the payload when ``noheader`` is true).
    """
    compressed = bool(ZLibBackend)
    if ZLibBackend:
        deflater = ZLibBackend.compressobj(wbits=-9)
        deflated = deflater.compress(message)
        deflated += deflater.flush(ZLibBackend.Z_SYNC_FLUSH)
        assert deflated.endswith(WS_DEFLATE_TRAILING)
        # Drop the 4-byte trailer whose presence was just asserted.
        message = deflated[:-4]

    payload_len = len(message)

    first_byte = (0x80 | opcode) if is_fin else opcode
    if compressed:
        first_byte |= 0x40

    mask_flag = 0x80 if mask else 0

    # Pick the length encoding: 7-bit inline, 16-bit (126) or 64-bit (127).
    if payload_len < 126:
        header = PACK_LEN1(first_byte, payload_len | mask_flag)
    elif payload_len < 65536:
        header = PACK_LEN2(first_byte, 126 | mask_flag, payload_len)
    else:
        header = PACK_LEN3(first_byte, 127 | mask_flag, payload_len)

    if mask:
        assert not noheader
        masking_key = PACK_RANDBITS(random.getrandbits(32))
        masked_payload = bytearray(message)
        websocket_mask(masking_key, masked_payload)  # masks in place
        return header + masking_key + masked_payload

    return message if noheader else header + message
def test_feed_data_remembers_exception(parser: WebSocketReader) -> None:
    """``feed_data`` keeps reporting an error once one has occurred.

    The first chunk is a two-byte header (0b01100000, 0) for which
    ``feed_data`` is expected to report an error; the following empty chunk
    must yield the same ``(True, b"")`` result.
    """
    bad_header = struct.pack("!BB", 0b01100000, 0b00000000)

    for chunk in (bad_header, b""):
        error, data = parser.feed_data(chunk)
        assert error is True
        assert data == b""
def test_parse_frame_length2_multi_byte_multi_packet(
    parser: PatchableWebSocketReader,
) -> None:
    """A 16-bit-length frame delivered in several packets parses correctly.

    The header, the two length bytes and the first three 8192-byte payload
    chunks must each produce no frame; the fourth chunk completes the
    32768-byte payload.
    """
    chunk = b"1" * 8192
    expected_payload = b"1" * 32768

    assert parser.parse_frame(struct.pack("!BB", 0b00000001, 126)) == []
    assert parser.parse_frame(struct.pack("!H", 32768)) == []
    for _ in range(3):
        assert parser.parse_frame(chunk) == []

    frames = parser.parse_frame(chunk)
    fin, opcode, payload, compress = frames[0]
    assert len(payload) == 32768
    assert (0, 1, expected_payload, 0) == (fin, opcode, payload, bool(compress))
# Proactor event loop will call feed_data with bytearray. Since
# asyncio technically supports memoryview as well, we should test that.
# NOTE(review): the parametrized ``data`` argument is not referenced in the
# test body below, so all three parametrizations currently run the same
# assertions -- confirm whether ``data`` was meant to be fed to the parser.
@pytest.mark.parametrize(
    argnames="data",
    argvalues=[b"", bytearray(b""), memoryview(b"")],
    ids=["bytes", "bytearray", "memoryview"],
)
def test_ping_frame(
    out: WebSocketDataQueue,
    parser: PatchableWebSocketReader,
    data: bytes | bytearray | memoryview,
) -> None:
    """A PING frame is queued as ``WSMessagePing`` carrying its payload."""
    parser._handle_frame(True, WSMsgType.PING, b"data", 0)
    # The message is read straight from the queue's internal buffer.
    res = out._buffer[0]
    assert res == WSMessagePing(data=b"data", size=4, extra="")
def test_one_byte_at_a_time(
    out: WebSocketDataQueue, parser: PatchableWebSocketReader
) -> None:
    """Feed a binary frame to the parser in single-byte chunks."""
    frame = build_frame(b"binary", WSMsgType.BINARY)

    # Iterating bytes yields ints; re-wrap each one as a 1-byte bytes object.
    for byte in frame:
        parser._feed_data(bytes((byte,)))

    queued = out._buffer[0]
    assert queued == WSMessageBinary(data=b"binary", size=6, extra="")
def test_large_fragmented_masked_message(
    out: WebSocketDataQueue, parser: PatchableWebSocketReader
) -> None:
    """A masked 128 KiB binary frame fed in 16 KiB slices is reassembled."""
    chunk_size = 16384
    large_payload = b"b" * 131072
    frame = build_frame(large_payload, WSMsgType.BINARY, mask=True)

    offset = 0
    while offset < len(frame):
        parser._feed_data(frame[offset : offset + chunk_size])
        offset += chunk_size

    queued = out._buffer[0]
    assert queued == WSMessageBinary(data=large_payload, size=131072, extra="")
def test_continuation_with_close(
    out: WebSocketDataQueue, parser: WebSocketReader
) -> None:
    """A CLOSE frame arriving inside a fragmented text message.

    The test expects the close message to be queued first, followed by the
    text message assembled from both fragments.
    """
    close_payload = build_close_frame(1002, b"test", noheader=True)

    parser._handle_frame(False, WSMsgType.TEXT, b"line1", 0)
    parser._handle_frame(False, WSMsgType.CLOSE, close_payload, False)
    parser._handle_frame(True, WSMsgType.CONTINUATION, b"line2", 0)

    close_msg = out._buffer[0]
    assert close_msg == WSMessageClose(data=1002, size=6, extra="test")
    text_msg = out._buffer[1]
    assert text_msg == WSMessageText(data="line1line2", size=10, extra="")
def test_continuation_with_close_empty(
    out: WebSocketDataQueue, parser: PatchableWebSocketReader
) -> None:
    """An empty CLOSE frame arriving inside a fragmented text message.

    Expected queue contents: a close message with code 0 and no text, then
    the text message joined from both fragments.
    """
    # (fin, opcode, payload) for each frame, in arrival order.
    frames = (
        (False, WSMsgType.TEXT, b"line1"),
        (False, WSMsgType.CLOSE, b""),
        (True, WSMsgType.CONTINUATION, b"line2"),
    )
    for fin, opcode, payload in frames:
        parser._handle_frame(fin, opcode, payload, 0)

    assert out._buffer[0] == WSMessageClose(data=0, size=0, extra="")
    assert out._buffer[1] == WSMessageText(data="line1line2", size=10, extra="")
def test_parse_compress_frame_multi(parser: PatchableWebSocketReader) -> None:
    """Check the compressed flag reported across consecutive frames.

    Three 4-byte frames with a 16-bit length field are parsed in sequence:
    a non-final frame with the 0x40 bit set, then two final frames without
    it. The test expects the flag to be reported as true for the first two
    frames and false for the third.
    """
    # (first header byte, expected fin, expected compressed flag)
    steps = (
        (0b01000001, 0, True),
        (0b10000001, 1, True),
        (0b10000001, 1, False),
    )
    for first_byte, want_fin, want_compressed in steps:
        parser.parse_frame(struct.pack("!BB", first_byte, 126))
        parser.parse_frame(struct.pack("!H", 4))
        frames = parser.parse_frame(b"1234")
        fin, opcode, payload, compress = frames[0]
        assert (want_fin, 1, b"1234", want_compressed) == (
            fin,
            opcode,
            payload,
            bool(compress),
        )
def test_flow_control_binary(
    protocol: BaseProtocol,
    out_low_limit: WebSocketDataQueue,
    parser_low_limit: PatchableWebSocketReader,
) -> None:
    """A 33-byte binary message on the low-limit queue pauses reading.

    The payload is ``1 + 16 * 2`` bytes (presumably one byte past twice the
    queue's limit of 16 -- confirm against the ``out_low_limit`` fixture).
    """
    payload = b"b" * (1 + 16 * 2)
    expected = WSMessageBinary(data=payload, size=len(payload), extra="")

    parser_low_limit._handle_frame(True, WSMsgType.BINARY, payload, 0)

    assert out_low_limit._buffer[0] == expected
    assert protocol._reading_paused is True
BaseProtocol, out_low_limit: WebSocketDataQueue, parser_low_limit: PatchableWebSocketReader, ) -> None: large_payload_text = "𒀁" * (1 + 16 * 2) large_payload = large_payload_text.encode("utf-8") large_payload_size = len(large_payload) parser_low_limit._handle_frame(True, WSMsgType.TEXT, large_payload, 0) res = out_low_limit._buffer[0] assert res == WSMessageText( data=large_payload_text, size=large_payload_size, extra="" ) assert protocol._reading_paused is True ================================================ FILE: tests/test_websocket_writer.py ================================================ import asyncio import random from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from contextlib import suppress from unittest import mock import pytest from aiohttp import WSMsgType from aiohttp._websocket.reader import WebSocketDataQueue from aiohttp.base_protocol import BaseProtocol from aiohttp.compression_utils import ZLibBackend from aiohttp.http import WebSocketReader, WebSocketWriter @pytest.fixture def protocol() -> mock.Mock: ret = mock.create_autospec(BaseProtocol, spec_set=True, instance=True) return ret # type: ignore[no-any-return] @pytest.fixture def transport() -> mock.Mock: ret = mock.Mock() ret.is_closing.return_value = False return ret @pytest.fixture def writer(protocol: BaseProtocol, transport: asyncio.Transport) -> WebSocketWriter: return WebSocketWriter(protocol, transport, use_mask=False) async def test_pong(writer: WebSocketWriter) -> None: await writer.send_frame(b"", WSMsgType.PONG) writer.transport.write.assert_called_with(b"\x8a\x00") # type: ignore[attr-defined] async def test_ping(writer: WebSocketWriter) -> None: await writer.send_frame(b"", WSMsgType.PING) writer.transport.write.assert_called_with(b"\x89\x00") # type: ignore[attr-defined] async def test_send_text(writer: WebSocketWriter) -> None: await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with(b"\x81\x04text") # 
type: ignore[attr-defined] async def test_send_binary(writer: WebSocketWriter) -> None: await writer.send_frame(b"binary", WSMsgType.BINARY) writer.transport.write.assert_called_with(b"\x82\x06binary") # type: ignore[attr-defined] async def test_send_binary_long(writer: WebSocketWriter) -> None: await writer.send_frame(b"b" * 127, WSMsgType.BINARY) assert writer.transport.write.call_args[0][0].startswith(b"\x82~\x00\x7fb") # type: ignore[attr-defined] async def test_send_binary_very_long(writer: WebSocketWriter) -> None: await writer.send_frame(b"b" * 65537, WSMsgType.BINARY) assert ( writer.transport.write.call_args_list[0][0][0] # type: ignore[attr-defined] == b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01" ) assert writer.transport.write.call_args_list[1][0][0] == b"b" * 65537 # type: ignore[attr-defined] async def test_close(writer: WebSocketWriter) -> None: await writer.close(1001, "msg") writer.transport.write.assert_called_with(b"\x88\x05\x03\xe9msg") # type: ignore[attr-defined] await writer.close(1001, b"msg") writer.transport.write.assert_called_with(b"\x88\x05\x03\xe9msg") # type: ignore[attr-defined] # Test that Service Restart close code is also supported await writer.close(1012, b"msg") writer.transport.write.assert_called_with(b"\x88\x05\x03\xf4msg") # type: ignore[attr-defined] async def test_send_text_masked( protocol: BaseProtocol, transport: asyncio.Transport ) -> None: writer = WebSocketWriter( protocol, transport, use_mask=True, random=random.Random(123) ) await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with(b"\x81\x84\rg\xb3fy\x02\xcb\x12") # type: ignore[attr-defined] @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_send_compress_text( protocol: BaseProtocol, transport: asyncio.Transport ) -> None: compress_obj = ZLibBackend.compressobj(level=ZLibBackend.Z_BEST_SPEED, wbits=-15) writer = WebSocketWriter(protocol, transport, compress=15) msg = ( compress_obj.compress(b"text") + 
compress_obj.flush(ZLibBackend.Z_SYNC_FLUSH) ).removesuffix(b"\x00\x00\xff\xff") await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with( # type: ignore[attr-defined] b"\xc1" + len(msg).to_bytes(1, "big") + msg ) msg = ( compress_obj.compress(b"text") + compress_obj.flush(ZLibBackend.Z_SYNC_FLUSH) ).removesuffix(b"\x00\x00\xff\xff") await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with( # type: ignore[attr-defined] b"\xc1" + len(msg).to_bytes(1, "big") + msg ) @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_send_compress_text_notakeover( protocol: BaseProtocol, transport: asyncio.Transport ) -> None: compress_obj = ZLibBackend.compressobj(level=ZLibBackend.Z_BEST_SPEED, wbits=-15) writer = WebSocketWriter(protocol, transport, compress=15, notakeover=True) msg = ( compress_obj.compress(b"text") + compress_obj.flush(ZLibBackend.Z_FULL_FLUSH) ).removesuffix(b"\x00\x00\xff\xff") await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with( # type: ignore[attr-defined] b"\xc1" + len(msg).to_bytes(1, "big") + msg ) await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with( # type: ignore[attr-defined] b"\xc1" + len(msg).to_bytes(1, "big") + msg ) async def test_send_compress_text_per_message( protocol: BaseProtocol, transport: asyncio.Transport ) -> None: writer = WebSocketWriter(protocol, transport) await writer.send_frame(b"text", WSMsgType.TEXT, compress=15) writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") # type: ignore[attr-defined] await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with(b"\x81\x04text") # type: ignore[attr-defined] await writer.send_frame(b"text", WSMsgType.TEXT, compress=15) writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") # type: ignore[attr-defined] @pytest.mark.usefixtures("parametrize_zlib_backend") async def 
test_send_compress_cancelled( protocol: BaseProtocol, transport: asyncio.Transport, slow_executor: ThreadPoolExecutor, monkeypatch: pytest.MonkeyPatch, ) -> None: """Test that cancelled compression doesn't corrupt subsequent sends. Regression test for https://github.com/aio-libs/aiohttp/issues/11725 """ monkeypatch.setattr("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024) writer = WebSocketWriter(protocol, transport, compress=15) loop = asyncio.get_running_loop() queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop) reader = WebSocketReader(queue, 50000) # Replace executor with slow one to make race condition reproducible writer._compressobj = writer._get_compressor(None) writer._compressobj._executor = slow_executor # Create large data that will trigger executor-based compression large_data_1 = b"A" * 10000 large_data_2 = b"B" * 10000 # Start first send and cancel it during compression async def send_and_cancel() -> None: await writer.send_frame(large_data_1, WSMsgType.BINARY) task = asyncio.create_task(send_and_cancel()) # Give it a moment to start compression await asyncio.sleep(0.01) task.cancel() # Await task cancellation (expected and intentionally ignored) with suppress(asyncio.CancelledError): await task # Send second message - this should NOT be corrupted await writer.send_frame(large_data_2, WSMsgType.BINARY) # Verify the second send produced correct data last_call = writer.transport.write.call_args_list[-1] # type: ignore[attr-defined] call_bytes = last_call[0][0] result, _ = reader.feed_data(call_bytes) assert result is False msg = await queue.read() assert msg.type is WSMsgType.BINARY # The data should be all B's, not mixed with A's from the cancelled send assert msg.data == large_data_2 @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_send_compress_multiple_cancelled( protocol: BaseProtocol, transport: asyncio.Transport, slow_executor: ThreadPoolExecutor, monkeypatch: pytest.MonkeyPatch, ) -> 
None: """Test that multiple compressed sends all complete despite cancellation. Regression test for https://github.com/aio-libs/aiohttp/issues/11725 This verifies that once a send operation enters the shield, it completes even if cancelled. With the lock inside the shield, all tasks that enter the shield will complete their sends, even while waiting for the lock. """ monkeypatch.setattr("aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", 1024) writer = WebSocketWriter(protocol, transport, compress=15) loop = asyncio.get_running_loop() queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop) reader = WebSocketReader(queue, 50000) # Replace executor with slow one writer._compressobj = writer._get_compressor(None) writer._compressobj._executor = slow_executor # Create 5 large messages with different content messages = [bytes([ord("A") + i]) * 10000 for i in range(5)] # Start sending all 5 messages - they'll queue due to the lock tasks = [ asyncio.create_task(writer.send_frame(msg, WSMsgType.BINARY)) for msg in messages ] # Cancel all tasks during execution # With lock inside shield, all tasks that enter the shield will complete # even while waiting for the lock await asyncio.sleep(0.1) # Let tasks enter the shield for task in tasks: task.cancel() # Collect results cancelled_count = 0 for task in tasks: try: await task except asyncio.CancelledError: cancelled_count += 1 # Wait for all background tasks to complete # (they continue running even after cancellation due to shield) await asyncio.gather(*writer._background_tasks, return_exceptions=True) # All tasks that entered the shield should complete, even if cancelled # With lock inside shield, all tasks enter shield immediately then wait for lock sent_count = len(writer.transport.write.call_args_list) # type: ignore[attr-defined] assert ( sent_count == 5 ), "All 5 sends should complete due to shield protecting lock acquisition" # Verify all sent messages are correct (no corruption) for i in 
range(sent_count): call = writer.transport.write.call_args_list[i] # type: ignore[attr-defined] call_bytes = call[0][0] result, _ = reader.feed_data(call_bytes) assert result is False msg = await queue.read() assert msg.type is WSMsgType.BINARY # Verify the data matches the expected message expected_byte = bytes([ord("A") + i]) assert msg.data == expected_byte * 10000, f"Message {i} corrupted" @pytest.mark.parametrize( ("max_sync_chunk_size", "payload_point_generator"), ( (16, lambda count: count), (4096, lambda count: count), (32, lambda count: 64 + count if count % 2 else count), ), ) @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_concurrent_messages( protocol: BaseProtocol, transport: asyncio.Transport, max_sync_chunk_size: int, payload_point_generator: Callable[[int], int], ) -> None: """Ensure messages are compressed correctly when there are multiple concurrent writers. This test generates is parametrized to - Generate messages that are larger than patch WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 16 where compression will run in the executor - Generate messages that are smaller than patch WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 4096 where compression will run in the event loop - Interleave generated messages with a WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 32 where compression will run in the event loop and in the executor """ with mock.patch( "aiohttp._websocket.writer.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", max_sync_chunk_size ): writer = WebSocketWriter(protocol, transport, compress=15) loop = asyncio.get_running_loop() queue = WebSocketDataQueue(mock.Mock(_reading_paused=False), 2**16, loop=loop) reader = WebSocketReader(queue, 50000) writers = [] payloads = [] for count in range(1, 64 + 1): point = payload_point_generator(count) payload = bytes((point,)) * point payloads.append(payload) writers.append(writer.send_frame(payload, WSMsgType.BINARY)) await asyncio.gather(*writers) for call in writer.transport.write.call_args_list: # type: ignore[attr-defined] call_bytes = 
call[0][0] result, _ = reader.feed_data(call_bytes) assert result is False msg = await queue.read() assert msg.type is WSMsgType.BINARY bytes_data = msg.data first_char = bytes_data[0:1] char_val = ord(first_char) assert len(bytes_data) == char_val # If we have a concurrency problem, the data # tends to get mixed up between messages so # we want to validate that all the bytes are # the same value assert bytes_data == bytes_data[0:1] * char_val # Wait for any background tasks to complete await asyncio.gather(*writer._background_tasks, return_exceptions=True) ================================================ FILE: tests/test_worker.py ================================================ # Tests for aiohttp/worker.py import asyncio import os import socket import ssl from typing import TYPE_CHECKING from unittest import mock import pytest from _pytest.fixtures import SubRequest from aiohttp import web if TYPE_CHECKING: from aiohttp import worker as base_worker else: base_worker = pytest.importorskip("aiohttp.worker") try: import uvloop except ImportError: uvloop = None # type: ignore[assignment] WRONG_LOG_FORMAT = '%a "%{Referrer}i" %(h)s %(l)s %s' ACCEPTABLE_LOG_FORMAT = '%a "%{Referrer}i" %s' class BaseTestWorker: def __init__(self) -> None: self.servers: dict[object, object] = {} self.exit_code = 0 self._notify_waiter: asyncio.Future[bool] | None = None self.cfg = mock.Mock() self.cfg.graceful_timeout = 100 self.pid = "pid" self.wsgi = web.Application() class AsyncioWorker(BaseTestWorker, base_worker.GunicornWebWorker): pass PARAMS = [AsyncioWorker] if uvloop is not None: class UvloopWorker(BaseTestWorker, base_worker.GunicornUVLoopWebWorker): pass PARAMS.append(UvloopWorker) @pytest.fixture(params=PARAMS) def worker( request: SubRequest, loop: asyncio.AbstractEventLoop ) -> base_worker.GunicornWebWorker: asyncio.set_event_loop(loop) ret = request.param() ret.notify = mock.Mock() return ret # type: ignore[no-any-return] def test_init_process(worker: 
base_worker.GunicornWebWorker) -> None: with mock.patch("aiohttp.worker.asyncio") as m_asyncio: try: worker.init_process() except TypeError: pass assert m_asyncio.new_event_loop.called assert m_asyncio.set_event_loop.called def test_run( worker: base_worker.GunicornWebWorker, loop: asyncio.AbstractEventLoop ) -> None: worker.log = mock.Mock() worker.cfg = mock.Mock() worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT worker.cfg.is_ssl = False worker.cfg.graceful_timeout = 100 worker.sockets = [] worker.loop = loop with pytest.raises(SystemExit): worker.run() worker.log.exception.assert_not_called() assert loop.is_closed() def test_run_async_factory( worker: base_worker.GunicornWebWorker, loop: asyncio.AbstractEventLoop ) -> None: worker.log = mock.Mock() worker.cfg = mock.Mock() worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT worker.cfg.is_ssl = False worker.cfg.graceful_timeout = 100 worker.sockets = [] app = worker.wsgi async def make_app() -> web.Application: return app # type: ignore[no-any-return] worker.wsgi = make_app worker.loop = loop worker.alive = False with pytest.raises(SystemExit): worker.run() worker.log.exception.assert_not_called() assert loop.is_closed() def test_run_not_app( worker: base_worker.GunicornWebWorker, loop: asyncio.AbstractEventLoop ) -> None: worker.log = mock.Mock() worker.cfg = mock.Mock() worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT worker.loop = loop worker.wsgi = "not-app" worker.alive = False with pytest.raises(SystemExit): worker.run() worker.log.exception.assert_called_with("Exception in gunicorn worker") assert loop.is_closed() def test_handle_abort(worker: base_worker.GunicornWebWorker) -> None: with mock.patch("aiohttp.worker.sys") as m_sys: worker.handle_abort(0, None) assert not worker.alive assert worker.exit_code == 1 m_sys.exit.assert_called_with(1) def test__wait_next_notify(worker: base_worker.GunicornWebWorker) -> None: worker.loop = mloop = mock.create_autospec(asyncio.AbstractEventLoop) with 
mock.patch.object(worker, "_notify_waiter_done", autospec=True): fut = worker._wait_next_notify() assert worker._notify_waiter == fut mloop.call_later.assert_called_with(1.0, worker._notify_waiter_done, fut) def test__notify_waiter_done(worker: base_worker.GunicornWebWorker) -> None: worker._notify_waiter = None worker._notify_waiter_done() assert worker._notify_waiter is None waiter = worker._notify_waiter = mock.Mock() worker._notify_waiter.done.return_value = False worker._notify_waiter_done() assert worker._notify_waiter is None waiter.set_result.assert_called_with(True) # type: ignore[unreachable] def test__notify_waiter_done_explicit_waiter( worker: base_worker.GunicornWebWorker, ) -> None: worker._notify_waiter = None assert worker._notify_waiter is None waiter = worker._notify_waiter = mock.Mock() waiter.done.return_value = False waiter2 = worker._notify_waiter = mock.Mock() worker._notify_waiter_done(waiter) assert worker._notify_waiter is waiter2 waiter.set_result.assert_called_with(True) assert not waiter2.set_result.called def test_init_signals(worker: base_worker.GunicornWebWorker) -> None: worker.loop = mock.Mock() worker.init_signals() assert worker.loop.add_signal_handler.called @pytest.mark.parametrize( "source,result", [ (ACCEPTABLE_LOG_FORMAT, ACCEPTABLE_LOG_FORMAT), ( AsyncioWorker.DEFAULT_GUNICORN_LOG_FORMAT, AsyncioWorker.DEFAULT_AIOHTTP_LOG_FORMAT, ), ], ) def test__get_valid_log_format_ok( worker: base_worker.GunicornWebWorker, source: str, result: str ) -> None: assert result == worker._get_valid_log_format(source) def test__get_valid_log_format_exc(worker: base_worker.GunicornWebWorker) -> None: with pytest.raises(ValueError) as exc: worker._get_valid_log_format(WRONG_LOG_FORMAT) assert "%(name)s" in str(exc.value) async def test__run_ok_parent_changed( worker: base_worker.GunicornWebWorker, loop: asyncio.AbstractEventLoop, unused_port_socket: socket.socket, ) -> None: worker.ppid = 0 worker.alive = True sock = unused_port_socket 
worker.sockets = [sock] worker.log = mock.Mock() worker.loop = loop worker.max_requests = 0 worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT worker.cfg.is_ssl = False await worker._run() worker.notify.assert_called_with() worker.log.info.assert_called_with("Parent changed, shutting down: %s", worker) async def test__run_exc( worker: base_worker.GunicornWebWorker, loop: asyncio.AbstractEventLoop, unused_port_socket: socket.socket, ) -> None: worker.ppid = os.getppid() worker.alive = True sock = unused_port_socket worker.sockets = [sock] worker.log = mock.Mock() worker.loop = loop worker.max_requests = 0 worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT worker.cfg.is_ssl = False def raiser() -> None: waiter = worker._notify_waiter worker.alive = False assert waiter is not None waiter.set_exception(RuntimeError()) loop.call_later(0.1, raiser) await worker._run() worker.notify.assert_called_with() def test__create_ssl_context_without_certs_and_ciphers( worker: base_worker.GunicornWebWorker, tls_certificate_pem_path: str, ) -> None: worker.cfg.ssl_version = ssl.PROTOCOL_TLS_CLIENT worker.cfg.cert_reqs = ssl.CERT_OPTIONAL worker.cfg.certfile = tls_certificate_pem_path worker.cfg.keyfile = tls_certificate_pem_path worker.cfg.ca_certs = None worker.cfg.ciphers = None ctx = worker._create_ssl_context(worker.cfg) assert isinstance(ctx, ssl.SSLContext) def test__create_ssl_context_with_ciphers( worker: base_worker.GunicornWebWorker, tls_certificate_pem_path: str, ) -> None: worker.cfg.ssl_version = ssl.PROTOCOL_TLS_CLIENT worker.cfg.cert_reqs = ssl.CERT_OPTIONAL worker.cfg.certfile = tls_certificate_pem_path worker.cfg.keyfile = tls_certificate_pem_path worker.cfg.ca_certs = None worker.cfg.ciphers = "3DES PSK" ctx = worker._create_ssl_context(worker.cfg) assert isinstance(ctx, ssl.SSLContext) def test__create_ssl_context_with_ca_certs( worker: base_worker.GunicornWebWorker, tls_ca_certificate_pem_path: str, tls_certificate_pem_path: str, ) -> None: 
worker.cfg.ssl_version = ssl.PROTOCOL_TLS_CLIENT worker.cfg.cert_reqs = ssl.CERT_OPTIONAL worker.cfg.certfile = tls_certificate_pem_path worker.cfg.keyfile = tls_certificate_pem_path worker.cfg.ca_certs = tls_ca_certificate_pem_path worker.cfg.ciphers = None ctx = worker._create_ssl_context(worker.cfg) assert isinstance(ctx, ssl.SSLContext) ================================================ FILE: tools/bench-asyncio-write.py ================================================ import asyncio import atexit import math import os import signal PORT = 8888 server = os.fork() if server == 0: loop = asyncio.get_event_loop() coro = asyncio.start_server(lambda *_: None, port=PORT) loop.run_until_complete(coro) loop.run_forever() else: atexit.register(os.kill, server, signal.SIGTERM) async def write_joined_bytearray(writer, chunks): body = bytearray(chunks[0]) for c in chunks[1:]: body += c writer.write(body) async def write_joined_list(writer, chunks): body = b"".join(chunks) writer.write(body) async def write_separately(writer, chunks): for c in chunks: writer.write(c) def fm_size(s, _fms=("", "K", "M", "G")): i = 0 while s >= 1024: s /= 1024 i += 1 return f"{s:.0f}{_fms[i]}B" def fm_time(s, _fms=("", "m", "µ", "n")): if s == 0: return "0" i = 0 while s < 1: s *= 1000 i += 1 return f"{s:.2f}{_fms[i]}s" def _job(j: list[int]) -> tuple[str, list[bytes]]: # Always start with a 256B headers chunk body = [b"0" * s for s in [256] + list(j)] job_title = f"{fm_size(sum(j))} / {len(j)}" return (job_title, body) writes = [ ("b''.join", write_joined_list), ("bytearray", write_joined_bytearray), ("multiple writes", write_separately), ] bodies = ( [], [10 * 2**0], [10 * 2**7], [10 * 2**17], [10 * 2**27], [50 * 2**27], [1 * 2**0 for _ in range(10)], [1 * 2**7 for _ in range(10)], [1 * 2**17 for _ in range(10)], [1 * 2**27 for _ in range(10)], [10 * 2**27 for _ in range(5)], ) jobs = [_job(j) for j in bodies] async def time(loop, fn, *args): spent = [] while not spent or sum(spent) < 0.2: s = 
loop.time() await fn(*args) e = loop.time() spent.append(e - s) mean = sum(spent) / len(spent) sd = sum((x - mean) ** 2 for x in spent) / len(spent) return len(spent), mean, math.sqrt(sd) async def main(loop): _, writer = await asyncio.open_connection(port=PORT) print("Loop:", loop) print("Transport:", writer._transport) res = [ ("size/chunks", "Write option", "Mean", "Std dev", "loops", "Variation"), ] res.append([":---", ":---", "---:", "---:", "---:", "---:"]) async def bench(job_title, w, body, base=None): it, mean, sd = await time(loop, w[1], writer, c) res.append( ( job_title, w[0], fm_time(mean), fm_time(sd), str(it), f"{mean / base - 1:.2%}" if base is not None else "", ) ) return mean for t, c in jobs: print("Doing", t) base = await bench(t, writes[0], c) for w in writes[1:]: await bench("", w, c, base) return res loop = asyncio.get_event_loop() results = loop.run_until_complete(main(loop)) with open("bench.md", "w") as f: for line in results: f.write("| {} |\n".format(" | ".join(line))) ================================================ FILE: tools/check_changes.py ================================================ #!/usr/bin/env python3 import re import sys from pathlib import Path ALLOWED_SUFFIXES = ( "bugfix", "feature", "deprecation", "breaking", "doc", "packaging", "contrib", "misc", ) PATTERN = re.compile( r"(\d+|[0-9a-f]{8}|[0-9a-f]{7}|[0-9a-f]{40})\.(" + "|".join(ALLOWED_SUFFIXES) + r")(\.\d+)?(\.rst)?", ) def get_root(script_path): folder = script_path.resolve().parent while not (folder / ".git").exists(): folder = folder.parent if folder == folder.anchor: raise RuntimeError("git repo not found") return folder def main(argv): print('Check "CHANGES" folder... 
', end="", flush=True) here = Path(argv[0]) root = get_root(here) changes = root / "CHANGES" failed = False for fname in changes.iterdir(): if fname.name in (".gitignore", ".TEMPLATE.rst", "README.rst"): continue if not PATTERN.match(fname.name): if not failed: print("") print("Illegal CHANGES record", fname, file=sys.stderr) failed = True if failed: print("", file=sys.stderr) print("See ./CHANGES/README.rst for the naming instructions", file=sys.stderr) print("", file=sys.stderr) else: print("OK") return int(failed) if __name__ == "__main__": sys.exit(main(sys.argv)) ================================================ FILE: tools/check_sum.py ================================================ #!/usr/bin/env python import argparse import hashlib import pathlib import sys PARSER = argparse.ArgumentParser( description="Helper for check file hashes in Makefile instead of bare timestamps" ) PARSER.add_argument("dst", metavar="DST", type=pathlib.Path) PARSER.add_argument("-d", "--debug", action="store_true", default=False) def main(argv): args = PARSER.parse_args(argv) dst = args.dst assert dst.suffix == ".hash" dirname = dst.parent if dirname.name != ".hash": if args.debug: print(f"Invalid name {dst} -> dirname {dirname}", file=sys.stderr) return 0 dirname.mkdir(exist_ok=True) src_dir = dirname.parent src_name = dst.stem # drop .hash full_src = src_dir / src_name hasher = hashlib.sha256() try: hasher.update(full_src.read_bytes()) except OSError: if args.debug: print(f"Cannot open {full_src}", file=sys.stderr) return 0 src_hash = hasher.hexdigest() if dst.exists(): dst_hash = dst.read_text() else: dst_hash = "" if src_hash != dst_hash: dst.write_text(src_hash) print(f"re-hash {src_hash}") else: if args.debug: print(f"Skip {src_hash} checksum, up-to-date") return 0 if __name__ == "__main__": sys.exit(main(sys.argv[1:])) ================================================ FILE: tools/cleanup_changes.py ================================================ #!/usr/bin/env python # Run 
me after the backport branch release to cleanup CHANGES records # that was backported and published. import re import subprocess from pathlib import Path ALLOWED_SUFFIXES = ( "bugfix", "feature", "deprecation", "breaking", "doc", "packaging", "contrib", "misc", ) PATTERN = re.compile( r"(\d+|[0-9a-f]{8}|[0-9a-f]{7}|[0-9a-f]{40})\.(" + "|".join(ALLOWED_SUFFIXES) + r")(\.\d+)?(\.rst)?", ) def main(): root = Path(__file__).parent.parent delete = [] changes = (root / "CHANGES.rst").read_text() for fname in (root / "CHANGES").iterdir(): match = PATTERN.match(fname.name) if match is not None: commit_issue_or_pr = match.group(1) tst_issue_or_pr = f":issue:`{commit_issue_or_pr}`" tst_commit = f":commit:`{commit_issue_or_pr}`" if tst_issue_or_pr in changes or tst_commit in changes: subprocess.run(["git", "rm", fname]) delete.append(fname.name) print("Deleted CHANGES records:", " ".join(delete)) print("Please verify and commit") if __name__ == "__main__": main() ================================================ FILE: tools/drop_merged_branches.sh ================================================ #!/usr/bin/env bash git remote prune origin ================================================ FILE: tools/gen.py ================================================ #!/usr/bin/env python import io import pathlib from collections import defaultdict import multidict ROOT = pathlib.Path.cwd() while ROOT.parent != ROOT and not (ROOT / "pyproject.toml").exists(): ROOT = ROOT.parent def calc_headers(root): hdrs_file = root / "aiohttp/hdrs.py" code = compile(hdrs_file.read_text(), str(hdrs_file), "exec") globs = {} exec(code, globs) headers = [val for val in globs.values() if isinstance(val, multidict.istr)] return sorted(headers) headers = calc_headers(ROOT) def factory(): return defaultdict(factory) TERMINAL = object() def build(headers): dct = defaultdict(factory) for hdr in headers: d = dct for ch in hdr: d = d[ch] d[TERMINAL] = hdr return dct dct = build(headers) HEADER = """\ /* The file is 
autogenerated from aiohttp/hdrs.py Run ./tools/gen.py to update it after the origin changing. */ #include "_find_header.h" #define NEXT_CHAR() \\ { \\ count++; \\ if (count == size) { \\ /* end of search */ \\ return -1; \\ } \\ pchar++; \\ ch = *pchar; \\ last = (count == size -1); \\ } while(0); int find_header(const char *str, int size) { char *pchar = str; int last; char ch; int count = -1; pchar--; """ BLOCK = """ {label} NEXT_CHAR(); switch (ch) {{ {cases} default: return -1; }} """ CASE = """\ case '{char}': if (last) {{ return {index}; }} goto {next};""" FOOTER = """ {missing} missing: /* nothing found */ return -1; }} """ def gen_prefix(prefix, k): if k == "-": return prefix + "_" else: return prefix + k.upper() def gen_block(dct, prefix, used_blocks, missing, out): cases = {} for k, v in dct.items(): if k is TERMINAL: continue next_prefix = gen_prefix(prefix, k) term = v.get(TERMINAL) if term is not None: index = headers.index(term) else: index = -1 hi = k.upper() lo = k.lower() case = CASE.format(char=hi, index=index, next=next_prefix) cases[hi] = case if lo != hi: case = CASE.format(char=lo, index=index, next=next_prefix) cases[lo] = case label = prefix + ":" if prefix else "" if cases: block = BLOCK.format(label=label, cases="\n".join(cases.values())) out.write(block) else: missing.add(label) for k, v in dct.items(): if not isinstance(v, defaultdict): continue block_name = gen_prefix(prefix, k) if block_name in used_blocks: continue used_blocks.add(block_name) gen_block(v, block_name, used_blocks, missing, out) def gen(dct): out = io.StringIO() out.write(HEADER) missing = set() gen_block(dct, "", set(), missing, out) missing_labels = "\n".join(sorted(missing)) out.write(FOOTER.format(missing=missing_labels)) return out def gen_headers(headers): out = io.StringIO() out.write("# The file is autogenerated from aiohttp/hdrs.py\n") out.write("# Run ./tools/gen.py to update it after the origin changing.") out.write("\n\n") out.write("from . 
import hdrs\n") out.write("cdef tuple headers = (\n") for hdr in headers: out.write(" hdrs.{},\n".format(hdr.upper().replace("-", "_"))) out.write(")\n") return out


# print(gen(dct).getvalue())
# print(gen_headers(headers).getvalue())

# Write both generated files into the aiohttp package directory.
folder = ROOT / "aiohttp"

with (folder / "_find_header.c").open("w") as f:
    f.write(gen(dct).getvalue())

with (folder / "_headers.pxi").open("w") as f:
    f.write(gen_headers(headers).getvalue())

================================================
FILE: tools/testing/Dockerfile
================================================
# Both build arguments are supplied by the caller at build time.
ARG PYTHON_VERSION
FROM python:$PYTHON_VERSION

ARG AIOHTTP_NO_EXTENSIONS
ENV AIOHTTP_NO_EXTENSIONS=$AIOHTTP_NO_EXTENSIONS

WORKDIR /deps

ADD ./requirements ./requirements
ADD Makefile .
RUN make install

ADD ./tools/testing/entrypoint.sh /

WORKDIR /src

ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]

================================================
FILE: tools/testing/Dockerfile.dockerignore
================================================
# Exclude everything from the build context except the paths re-included below.
*
!/requirements
!/Makefile
!/tools/testing/entrypoint.sh

================================================
FILE: tools/testing/entrypoint.sh
================================================
#!/bin/bash

# Run "make cythonize" unless AIOHTTP_NO_EXTENSIONS is "y".
[[ "$AIOHTTP_NO_EXTENSIONS" != "y" ]] && make cythonize

# The first script argument is forwarded to pytest.
python -m pytest -qx --no-cov $1

================================================
FILE: vendor/README.rst
================================================
LLHTTP
------

When building aiohttp from source, there is a pure Python parser used by
default. For better performance, you may want to build the higher performance
C parser.

To build this ``llhttp`` parser, first get/update the submodules (to update to a
newer release, add ``--remote``)::

    git submodule update --init --recursive

Then build ``llhttp``::

    cd vendor/llhttp/
    npm ci
    make

Then build our parser::

    cd -
    make cythonize

Then you can build or install it with ``python -m build`` or ``pip install -e .``