Repository: MagicStack/asyncpg Branch: master Commit: db8ecc2a38e1 Files: 121 Total size: 1.1 MB Directory structure: gitextract_qed7mic1/ ├── .clang-format ├── .clangd ├── .flake8 ├── .github/ │ ├── ISSUE_TEMPLATE.md │ ├── RELEASING.rst │ ├── release_log.py │ └── workflows/ │ ├── install-krb5.sh │ ├── install-postgres.sh │ ├── release.yml │ └── tests.yml ├── .gitignore ├── .gitmodules ├── AUTHORS ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.rst ├── asyncpg/ │ ├── .gitignore │ ├── __init__.py │ ├── _asyncio_compat.py │ ├── _testbase/ │ │ ├── __init__.py │ │ └── fuzzer.py │ ├── _version.py │ ├── cluster.py │ ├── compat.py │ ├── connect_utils.py │ ├── connection.py │ ├── connresource.py │ ├── cursor.py │ ├── exceptions/ │ │ ├── __init__.py │ │ └── _base.py │ ├── introspection.py │ ├── pool.py │ ├── prepared_stmt.py │ ├── protocol/ │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── codecs/ │ │ │ ├── __init__.py │ │ │ ├── array.pyx │ │ │ ├── base.pxd │ │ │ ├── base.pyx │ │ │ ├── pgproto.pyx │ │ │ ├── range.pyx │ │ │ ├── record.pyx │ │ │ └── textutils.pyx │ │ ├── consts.pxi │ │ ├── coreproto.pxd │ │ ├── coreproto.pyx │ │ ├── cpythonx.pxd │ │ ├── encodings.pyx │ │ ├── pgtypes.pxi │ │ ├── prepared_stmt.pxd │ │ ├── prepared_stmt.pyx │ │ ├── protocol.pxd │ │ ├── protocol.pyi │ │ ├── protocol.pyx │ │ ├── record/ │ │ │ ├── pythoncapi_compat.h │ │ │ ├── pythoncapi_compat_extras.h │ │ │ ├── recordobj.c │ │ │ └── recordobj.h │ │ ├── record.pyi │ │ ├── recordcapi.pxd │ │ ├── scram.pxd │ │ ├── scram.pyx │ │ ├── settings.pxd │ │ └── settings.pyx │ ├── serverversion.py │ ├── transaction.py │ ├── types.py │ └── utils.py ├── docs/ │ ├── .gitignore │ ├── Makefile │ ├── _static/ │ │ └── theme_overrides.css │ ├── api/ │ │ └── index.rst │ ├── conf.py │ ├── faq.rst │ ├── index.rst │ ├── installation.rst │ ├── requirements.txt │ └── usage.rst ├── pyproject.toml ├── setup.py ├── tests/ │ ├── __init__.py │ ├── certs/ │ │ ├── ca.cert.pem │ │ ├── ca.crl.pem │ │ ├── ca.key.pem │ │ ├── 
client.cert.pem │ │ ├── client.csr.pem │ │ ├── client.key.pem │ │ ├── client.key.protected.pem │ │ ├── client_ca.cert.pem │ │ ├── client_ca.cert.srl │ │ ├── client_ca.key.pem │ │ ├── gen.py │ │ ├── server.cert.pem │ │ ├── server.crl.pem │ │ └── server.key.pem │ ├── test__environment.py │ ├── test__sourcecode.py │ ├── test_adversity.py │ ├── test_cache_invalidation.py │ ├── test_cancellation.py │ ├── test_codecs.py │ ├── test_connect.py │ ├── test_copy.py │ ├── test_cursor.py │ ├── test_exceptions.py │ ├── test_execute.py │ ├── test_introspection.py │ ├── test_listeners.py │ ├── test_logging.py │ ├── test_pool.py │ ├── test_prepare.py │ ├── test_record.py │ ├── test_subinterpreters.py │ ├── test_test.py │ ├── test_timeout.py │ ├── test_transaction.py │ ├── test_types.py │ └── test_utils.py └── tools/ ├── generate_exceptions.py └── generate_type_map.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .clang-format ================================================ # A clang-format style that approximates Python's PEP 7 BasedOnStyle: Google AlwaysBreakAfterReturnType: All AllowShortIfStatementsOnASingleLine: false AlignAfterOpenBracket: Align BreakBeforeBraces: Stroustrup ColumnLimit: 95 DerivePointerAlignment: false IndentWidth: 4 Language: Cpp PointerAlignment: Right ReflowComments: true SpaceBeforeParens: ControlStatements SpacesInParentheses: false TabWidth: 4 UseTab: Never SortIncludes: false ================================================ FILE: .clangd ================================================ Diagnostics: Includes: IgnoreHeader: - "pythoncapi_compat.*\\.h" ================================================ FILE: .flake8 ================================================ [flake8] select = C90,E,F,W,Y0 ignore = E402,E731,W503,W504,E252 exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv*,.tox per-file-ignores = *.pyi: 
F401,F403,F405,F811,E127,E128,E203,E266,E301,E302,E305,E501,E701,E704,E741,B303,W503,W504 ================================================ FILE: .github/ISSUE_TEMPLATE.md ================================================ * **asyncpg version**: * **PostgreSQL version**: * **Do you use a PostgreSQL SaaS? If so, which? Can you reproduce the issue with a local PostgreSQL install?**: * **Python version**: * **Platform**: * **Do you use pgbouncer?**: * **Did you install asyncpg with pip?**: * **If you built asyncpg locally, which version of Cython did you use?**: * **Can the issue be reproduced under both asyncio and [uvloop](https://github.com/magicstack/uvloop)?**: ================================================ FILE: .github/RELEASING.rst ================================================ Releasing asyncpg ================= When making an asyncpg release follow the below checklist. 1. Remove the ``.dev0`` suffix from ``__version__`` in ``asyncpg/__init__.py``. 2. Make a release commit: .. code-block:: shell $ git commit -a -m "asyncpg vX.Y.0" Here, X.Y.0 is the ``__version__`` in ``asyncpg/__init__.py``. 3. Force push into the "releases" branch on Github: .. code-block:: shell $ git push --force origin master:releases 4. Wait for CI to make the release build. If there are errors, investigate, fix and repeat steps 2 through 4. 5. Prepare the release changelog by cleaning and categorizing the output of ``.github/release_log.py``. Look at previous releases for examples of changelog formatting: .. code-block:: shell $ .github/release_log.py 6. Make an annotated, signed git tag and use the changelog as the tag annotation: .. code-block:: shell $ git tag -s vX.Y.0 7. Push the release commit and the new tag to master on Github: .. code-block:: shell $ git push --follow-tags 8. Wait for CI to publish the build to PyPI. 9. Edit the release on Github and paste the same content you used for the tag annotation (Github treats tag annotations as plain text, rather than Markdown.) 10. 
Open master for development by bumping the minor component of ``__version__`` in ``asyncpg/__init__.py`` and appending the ``.dev0`` suffix. ================================================ FILE: .github/release_log.py ================================================ #!/usr/bin/env python3 # # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 import json import requests import re import sys BASE_URL = 'https://api.github.com/repos/magicstack/asyncpg/compare' def main(): if len(sys.argv) < 2: print('pass a sha1 hash as a first argument') sys.exit(1) from_hash = sys.argv[1] if len(sys.argv) > 2: to_hash = sys.argv[2] r = requests.get(f'{BASE_URL}/{from_hash}...{to_hash}') data = json.loads(r.text) for commit in data['commits']: message = commit['commit']['message'] first_line = message.partition('\n\n')[0] if commit.get('author'): username = '@{}'.format(commit['author']['login']) else: username = commit['commit']['author']['name'] sha = commit["sha"][:8] m = re.search(r'\#(?P\d+)\b', message) if m: issue_num = m.group('num') else: issue_num = None print(f'* {first_line}') print(f' (by {username} in {sha}', end='') print(')') print() if __name__ == '__main__': main() ================================================ FILE: .github/workflows/install-krb5.sh ================================================ #!/bin/bash set -Eexuo pipefail shopt -s nullglob if [[ $OSTYPE == linux* ]]; then if [ "$(id -u)" = "0" ]; then SUDO= else SUDO=sudo fi if [ -e /etc/os-release ]; then source /etc/os-release elif [ -e /etc/centos-release ]; then ID="centos" VERSION_ID=$(cat /etc/centos-release | cut -f3 -d' ' | cut -f1 -d.) 
# install-krb5.sh (continued): finish distro detection, then install the Kerberos packages.
    else
        echo "install-krb5.sh: cannot determine which Linux distro this is" >&2
        exit 1
    fi

    if [ "${ID}" = "debian" -o "${ID}" = "ubuntu" ]; then
        export DEBIAN_FRONTEND=noninteractive

        $SUDO apt-get update
        $SUDO apt-get install -y --no-install-recommends \
            libkrb5-dev krb5-user krb5-kdc krb5-admin-server
    elif [ "${ID}" = "almalinux" ]; then
        $SUDO dnf install -y krb5-server krb5-workstation krb5-libs krb5-devel
    elif [ "${ID}" = "centos" ]; then
        $SUDO yum install -y krb5-server krb5-workstation krb5-libs krb5-devel
    elif [ "${ID}" = "alpine" ]; then
        $SUDO apk add krb5 krb5-server krb5-dev
    else
        # Report ${ID}, the variable every branch above tests.  The script
        # runs under `set -u`, so expanding a name that was never assigned
        # would abort with "unbound variable" before this message is shown.
        echo "install-krb5.sh: Unsupported linux distro: ${ID}" >&2
        exit 1
    fi
else
    echo "install-krb5.sh: unsupported OS: ${OSTYPE}" >&2
    exit 1
fi
================================================ FILE: .github/workflows/install-postgres.sh ================================================ #!/bin/bash set -Eexuo pipefail shopt -s nullglob if [[ $OSTYPE == linux* ]]; then PGVERSION=${PGVERSION:-12} if [ -e /etc/os-release ]; then source /etc/os-release elif [ -e /etc/centos-release ]; then ID="centos" VERSION_ID=$(cat /etc/centos-release | cut -f3 -d' ' | cut -f1 -d.)
# install-postgres.sh (continued): finish distro detection, then install PostgreSQL.
    else
        echo "install-postgres.sh: cannot determine which Linux distro this is" >&2
        exit 1
    fi

    if [ "${ID}" = "debian" -o "${ID}" = "ubuntu" ]; then
        export DEBIAN_FRONTEND=noninteractive

        apt-get install -y --no-install-recommends curl gnupg ca-certificates
        curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add -
        mkdir -p /etc/apt/sources.list.d/
        echo "deb https://apt.postgresql.org/pub/repos/apt/ ${VERSION_CODENAME}-pgdg main" \
            >> /etc/apt/sources.list.d/pgdg.list
        apt-get update
        apt-get install -y --no-install-recommends \
            "postgresql-${PGVERSION}" \
            "postgresql-contrib-${PGVERSION}"
    elif [ "${ID}" = "almalinux" ]; then
        yum install -y \
            "postgresql-server" \
            "postgresql-devel" \
            "postgresql-contrib"
    elif [ "${ID}" = "centos" ]; then
        el="EL-${VERSION_ID%.*}-$(arch)"
        baseurl="https://download.postgresql.org/pub/repos/yum/reporpms"
        yum install -y "${baseurl}/${el}/pgdg-redhat-repo-latest.noarch.rpm"
        if [ ${VERSION_ID%.*} -ge 8 ]; then
            dnf -qy module disable postgresql
        fi
        yum install -y \
            "postgresql${PGVERSION}-server" \
            "postgresql${PGVERSION}-contrib"
        ln -s "/usr/pgsql-${PGVERSION}/bin/pg_config" "/usr/local/bin/pg_config"
    elif [ "${ID}" = "alpine" ]; then
        apk add shadow postgresql postgresql-dev postgresql-contrib
    else
        # Report ${ID}, the variable every branch above tests.  The script
        # runs under `set -u`, so expanding a name that was never assigned
        # would abort with "unbound variable" before this message is shown.
        echo "install-postgres.sh: unsupported Linux distro: ${ID}" >&2
        exit 1
    fi

    useradd -m -s /bin/bash apgtest

elif [[ $OSTYPE == darwin* ]]; then
    brew install postgresql

else
    echo "install-postgres.sh: unsupported OS: ${OSTYPE}" >&2
    exit 1
fi
================================================ FILE: .github/workflows/release.yml ================================================ name: Release on: pull_request: branches: - "master" - "ci" - "[0-9]+.[0-9x]+*" paths: - "asyncpg/_version.py" jobs: validate-release-request: runs-on: ubuntu-latest steps: - name: Validate release PR uses: edgedb/action-release/validate-pr@master id: checkver with: require_team: Release Managers require_approval: no github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN
}} version_file: asyncpg/_version.py version_line_pattern: | __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) - name: Stop if not approved if: steps.checkver.outputs.approved != 'true' run: | echo ::error::PR is not approved yet. exit 1 - name: Store release version for later use env: VERSION: ${{ steps.checkver.outputs.version }} run: | mkdir -p dist/ echo "${VERSION}" > dist/VERSION - uses: actions/upload-artifact@v4 with: name: dist-version path: dist/VERSION build-sdist: needs: validate-release-request runs-on: ubuntu-latest env: PIP_DISABLE_PIP_VERSION_CHECK: 1 steps: - uses: actions/checkout@v5 with: fetch-depth: 50 submodules: true persist-credentials: false - name: Set up Python uses: actions/setup-python@v6 with: python-version: "3.x" - name: Build source distribution run: | pip install -U setuptools wheel pip python setup.py sdist - uses: actions/upload-artifact@v4 with: name: dist-sdist path: dist/*.tar.* build-wheels-matrix: needs: validate-release-request runs-on: ubuntu-latest outputs: include: ${{ steps.set-matrix.outputs.include }} steps: - uses: actions/checkout@v5 with: persist-credentials: false - uses: actions/setup-python@v6 with: python-version: "3.x" - run: pip install cibuildwheel==3.3.0 - id: set-matrix run: | MATRIX_INCLUDE=$( { cibuildwheel --print-build-identifiers --platform linux --archs x86_64,aarch64 | grep cp | jq -nRc '{"only": inputs, "os": "ubuntu-latest"}' \ && cibuildwheel --print-build-identifiers --platform macos --archs x86_64,arm64 | grep cp | jq -nRc '{"only": inputs, "os": "macos-latest"}' \ && cibuildwheel --print-build-identifiers --platform windows --archs x86,AMD64 | grep cp | jq -nRc '{"only": inputs, "os": "windows-latest"}' } | jq -sc ) echo "include=$MATRIX_INCLUDE" >> $GITHUB_OUTPUT build-wheels: needs: build-wheels-matrix runs-on: ${{ matrix.os }} name: Build ${{ matrix.only }} strategy: fail-fast: false matrix: include: ${{ fromJson(needs.build-wheels-matrix.outputs.include) }} 
defaults: run: shell: bash env: PIP_DISABLE_PIP_VERSION_CHECK: 1 steps: - uses: actions/checkout@v5 with: fetch-depth: 50 submodules: true persist-credentials: false - name: Set up QEMU if: runner.os == 'Linux' uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3.6.0 - uses: pypa/cibuildwheel@63fd63b352a9a8bdcc24791c9dbee952ee9a8abc # v3.3.0 with: only: ${{ matrix.only }} env: CIBW_BUILD_VERBOSITY: 1 - uses: actions/upload-artifact@v4 with: name: dist-wheels-${{ matrix.only }} path: wheelhouse/*.whl merge-artifacts: runs-on: ubuntu-latest needs: [build-sdist, build-wheels] steps: - name: Merge Artifacts uses: actions/upload-artifact/merge@v4 with: name: dist delete-merged: true publish-docs: needs: [build-sdist, build-wheels] runs-on: ubuntu-latest env: PIP_DISABLE_PIP_VERSION_CHECK: 1 steps: - name: Checkout source uses: actions/checkout@v5 with: fetch-depth: 5 submodules: true persist-credentials: false - name: Set up Python uses: actions/setup-python@v6 with: python-version: "3.x" - name: Build docs run: | pip install --group docs pip install -e . make htmldocs - name: Checkout gh-pages uses: actions/checkout@v5 with: fetch-depth: 5 ref: gh-pages path: docs/gh-pages persist-credentials: false - name: Sync docs run: | rsync -a docs/_build/html/ docs/gh-pages/current/ - name: Commit and push docs uses: magicstack/gha-commit-and-push@master with: target_branch: gh-pages workdir: docs/gh-pages commit_message: Automatic documentation update github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} ssh_key: ${{ secrets.RELEASE_BOT_SSH_KEY }} gpg_key: ${{ secrets.RELEASE_BOT_GPG_KEY }} gpg_key_id: "5C468778062D87BF!" 
publish: needs: [build-sdist, build-wheels, publish-docs] runs-on: ubuntu-latest environment: name: pypi url: https://pypi.org/p/asyncpg permissions: id-token: write attestations: write contents: write deployments: write steps: - uses: actions/checkout@v5 with: fetch-depth: 5 submodules: false persist-credentials: false - uses: actions/download-artifact@v4 with: name: dist path: dist/ - name: Extract Release Version id: relver run: | set -e echo "version=$(cat dist/VERSION)" >> $GITHUB_OUTPUT rm dist/VERSION - name: Merge and tag the PR uses: edgedb/action-release/merge@master with: github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} ssh_key: ${{ secrets.RELEASE_BOT_SSH_KEY }} gpg_key: ${{ secrets.RELEASE_BOT_GPG_KEY }} gpg_key_id: "5C468778062D87BF!" tag_name: v${{ steps.relver.outputs.version }} - name: Publish Github Release uses: elprans/gh-action-create-release@master env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: tag_name: v${{ steps.relver.outputs.version }} release_name: v${{ steps.relver.outputs.version }} target: ${{ github.event.pull_request.base.ref }} body: ${{ github.event.pull_request.body }} - run: | ls -al dist/ - name: Upload to PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: attestations: true ================================================ FILE: .github/workflows/tests.yml ================================================ name: Tests on: push: branches: - master - ci pull_request: branches: - master jobs: test-platforms: # NOTE: this matrix is for testing various combinations of Python and OS # versions on the system-installed PostgreSQL version (which is usually # fairly recent). For a PostgreSQL version matrix see the test-postgres # job. 
strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14", "3.14t"] os: [ubuntu-latest, macos-latest, windows-latest] loop: [asyncio, uvloop] exclude: # uvloop does not support windows - loop: uvloop os: windows-latest runs-on: ${{ matrix.os }} permissions: {} defaults: run: shell: bash env: PIP_DISABLE_PIP_VERSION_CHECK: 1 steps: - uses: actions/checkout@v5 with: fetch-depth: 50 submodules: true persist-credentials: false - name: Check if release PR. uses: edgedb/action-release/validate-pr@master id: release with: github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} missing_version_ok: yes version_file: asyncpg/_version.py version_line_pattern: | __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) - name: Setup PostgreSQL if: "!steps.release.outputs.is_release && matrix.os == 'macos-latest'" run: | POSTGRES_FORMULA="postgresql@18" brew install "$POSTGRES_FORMULA" echo "$(brew --prefix "$POSTGRES_FORMULA")/bin" >> $GITHUB_PATH - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v6 if: "!steps.release.outputs.is_release" with: python-version: ${{ matrix.python-version }} - name: Install Python Deps if: "!steps.release.outputs.is_release" run: | [ "$RUNNER_OS" = "Linux" ] && .github/workflows/install-krb5.sh python -m pip install -U pip setuptools wheel python -m pip install --group test python -m pip install -e . - name: Test if: "!steps.release.outputs.is_release" env: LOOP_IMPL: ${{ matrix.loop }} run: | if [ "${LOOP_IMPL}" = "uvloop" ]; then env USE_UVLOOP=1 python -m unittest -v tests.suite else python -m unittest -v tests.suite fi test-postgres: strategy: matrix: postgres-version: ["9.5", "9.6", "10", "11", "12", "13", "14", "15", "16", "17", "18"] runs-on: ubuntu-latest permissions: {} env: PIP_DISABLE_PIP_VERSION_CHECK: 1 steps: - uses: actions/checkout@v5 with: fetch-depth: 50 submodules: true persist-credentials: false - name: Check if release PR. 
uses: edgedb/action-release/validate-pr@master id: release with: github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }} missing_version_ok: yes version_file: asyncpg/_version.py version_line_pattern: | __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"]) - name: Set up PostgreSQL if: "!steps.release.outputs.is_release" env: PGVERSION: ${{ matrix.postgres-version }} DISTRO_NAME: focal run: | sudo env DISTRO_NAME="${DISTRO_NAME}" PGVERSION="${PGVERSION}" \ .github/workflows/install-postgres.sh echo PGINSTALLATION="/usr/lib/postgresql/${PGVERSION}/bin" \ >> "${GITHUB_ENV}" - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v6 if: "!steps.release.outputs.is_release" with: python-version: "3.x" - name: Install Python Deps if: "!steps.release.outputs.is_release" run: | [ "$RUNNER_OS" = "Linux" ] && .github/workflows/install-krb5.sh python -m pip install -U pip setuptools wheel python -m pip install --group test python -m pip install -e . - name: Test if: "!steps.release.outputs.is_release" env: PGVERSION: ${{ matrix.postgres-version }} run: | python -m unittest -v tests.suite # This job exists solely to act as the test job aggregate to be # targeted by branch policies. 
regression-tests: name: "Regression Tests" needs: [test-platforms, test-postgres] runs-on: ubuntu-latest permissions: {} steps: - run: echo OK ================================================ FILE: .gitignore ================================================ *._* *.pyc *.pyo *.ymlc *.ymlc~ *.scssc *.so *.pyd *~ .#* .DS_Store .project .pydevproject .settings .idea /.ropeproject \#*# /pub /test*.py /.local /perf.data* /config_local.yml /build __pycache__/ .d8_history /*.egg /*.egg-info /dist /.cache docs/_build *,cover .coverage /.pytest_cache/ /.eggs /.vscode /.zed /.mypy_cache /.venv* /.tox /compile_commands.json ================================================ FILE: .gitmodules ================================================ [submodule "asyncpg/pgproto"] path = asyncpg/pgproto url = https://github.com/MagicStack/py-pgproto.git ================================================ FILE: AUTHORS ================================================ Main contributors ================= MagicStack Inc.: Elvis Pranskevichus Yury Selivanov ================================================ FILE: LICENSE ================================================ Copyright (C) 2016-present the asyncpg authors and contributors. Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright (C) 2016-present the asyncpg authors and contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
================================================ FILE: MANIFEST.in ================================================ recursive-include docs *.py *.rst Makefile *.css recursive-include examples *.py recursive-include tests *.py *.pem recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.pyi *.c *.h include LICENSE README.rst Makefile performance.png .flake8
================================================ FILE: Makefile ================================================
.PHONY: compile debug test quicktest clean all


PYTHON ?= python
ROOT = $(dir $(realpath $(firstword $(MAKEFILE_LIST))))


all: compile


# Remove build products: distributions, built docs, Cython-generated C and
# HTML annotation files, compiled extension modules and bytecode caches.
clean:
	# The Sphinx tree lives in docs/, so that is where _build/ is created.
	rm -fr dist/ docs/_build/
	rm -fr asyncpg/pgproto/*.c asyncpg/pgproto/*.html
	rm -fr asyncpg/pgproto/codecs/*.html
	rm -fr asyncpg/pgproto/*.so
	rm -fr asyncpg/protocol/*.c asyncpg/protocol/*.html
	rm -fr asyncpg/protocol/*.so build *.egg-info
	rm -fr asyncpg/protocol/codecs/*.html
	find . -name '__pycache__' | xargs rm -rf


# Editable install that forces the Cython sources to be regenerated.
compile:
	env ASYNCPG_BUILD_CYTHON_ALWAYS=1 $(PYTHON) -m pip install -e .


# Editable install with ASYNCPG_DEBUG=1 set for the build.
debug:
	env ASYNCPG_DEBUG=1 $(PYTHON) -m pip install -e .
test: PYTHONASYNCIODEBUG=1 $(PYTHON) -m unittest -v tests.suite $(PYTHON) -m unittest -v tests.suite USE_UVLOOP=1 $(PYTHON) -m unittest -v tests.suite testinstalled: cd "$${HOME}" && $(PYTHON) $(ROOT)/tests/__init__.py quicktest: $(PYTHON) -m unittest -v tests.suite htmldocs: $(PYTHON) -m pip install -e .[docs] $(MAKE) -C docs html ================================================ FILE: README.rst ================================================ asyncpg -- A fast PostgreSQL Database Client Library for Python/asyncio ======================================================================= .. image:: https://github.com/MagicStack/asyncpg/workflows/Tests/badge.svg :target: https://github.com/MagicStack/asyncpg/actions?query=workflow%3ATests+branch%3Amaster :alt: GitHub Actions status .. image:: https://img.shields.io/pypi/v/asyncpg.svg :target: https://pypi.python.org/pypi/asyncpg **asyncpg** is a database interface library designed specifically for PostgreSQL and Python/asyncio. asyncpg is an efficient, clean implementation of PostgreSQL server binary protocol for use with Python's ``asyncio`` framework. You can read more about asyncpg in an introductory `blog post `_. asyncpg requires Python 3.9 or later and is supported for PostgreSQL versions 9.5 to 18. Other PostgreSQL versions or other databases implementing the PostgreSQL protocol *may* work, but are not being actively tested. Documentation ------------- The project documentation can be found `here `_. Performance ----------- In our testing asyncpg is, on average, **5x** faster than psycopg3. .. image:: https://raw.githubusercontent.com/MagicStack/asyncpg/master/performance.png?fddca40ab0 :target: https://gistpreview.github.io/?0ed296e93523831ea0918d42dd1258c2 The above results are a geometric mean of benchmarks obtained with PostgreSQL `client driver benchmarking toolbench `_ in June 2023 (click on the chart to see full details). 
Features -------- asyncpg implements PostgreSQL server protocol natively and exposes its features directly, as opposed to hiding them behind a generic facade like DB-API. This enables asyncpg to have easy-to-use support for: * **prepared statements** * **scrollable cursors** * **partial iteration** on query results * automatic encoding and decoding of composite types, arrays, and any combination of those * straightforward support for custom data types Installation ------------ asyncpg is available on PyPI. When not using GSSAPI/SSPI authentication it has no dependencies. Use pip to install:: $ pip install asyncpg If you need GSSAPI/SSPI authentication, use:: $ pip install 'asyncpg[gssauth]' For more details, please `see the documentation `_. Basic Usage ----------- .. code-block:: python import asyncio import asyncpg async def run(): conn = await asyncpg.connect(user='user', password='password', database='database', host='127.0.0.1') values = await conn.fetch( 'SELECT * FROM mytable WHERE id = $1', 10, ) await conn.close() asyncio.run(run()) License ------- asyncpg is developed and distributed under the Apache 2.0 license. ================================================ FILE: asyncpg/.gitignore ================================================ *.html ================================================ FILE: asyncpg/__init__.py ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 from __future__ import annotations from .connection import connect, Connection # NOQA from .exceptions import * # NOQA from .pool import create_pool, Pool # NOQA from .protocol import Record # NOQA from .types import * # NOQA from ._version import __version__ # NOQA from . import exceptions __all__: tuple[str, ...] 
async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
    """Wait for the single Future or coroutine to complete, with timeout.

    Coroutine will be wrapped in Task.

    Returns result of the Future or coroutine.  When a timeout occurs,
    it cancels the task and raises asyncio.TimeoutError.  To avoid the
    task cancellation, wrap it in shield().

    If the wait is cancelled, the task is also cancelled.

    If the task suppresses the cancellation and returns a value instead,
    that value is returned.

    This function is a coroutine.
    """
    # The special case for timeout <= 0 is for the following case:
    #
    # async def test_waitfor():
    #     func_started = False
    #
    #     async def func():
    #         nonlocal func_started
    #         func_started = True
    #
    #     try:
    #         await asyncio.wait_for(func(), 0)
    #     except asyncio.TimeoutError:
    #         assert not func_started
    #     else:
    #         assert False
    #
    # asyncio.run(test_waitfor())

    if timeout is not None and timeout <= 0:
        fut = asyncio.ensure_future(fut)

        if fut.done():
            return fut.result()

        await _cancel_and_wait(fut)
        try:
            return fut.result()
        except asyncio.CancelledError as exc:
            # Raise the asyncio flavour of the exception: before Python 3.11
            # asyncio.TimeoutError is a class distinct from the builtin
            # TimeoutError, and it is what the timeout context manager used
            # below raises.  On 3.11+ the two names refer to the same class.
            raise asyncio.TimeoutError from exc

    async with timeout_ctx(timeout):
        return await fut


async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
    """Cancel the *fut* future or task and wait until it completes."""

    loop = asyncio.get_running_loop()
    waiter = loop.create_future()
    # The waiter is resolved by a done-callback on *fut* rather than by
    # awaiting *fut* itself (see the comment below).
    cb = functools.partial(_release_waiter, waiter)
    fut.add_done_callback(cb)

    try:
        fut.cancel()
        # We cannot wait on *fut* directly to make
        # sure _cancel_and_wait itself is reliably cancellable.
        await waiter
    finally:
        fut.remove_done_callback(cb)


def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None:
    """Resolve *waiter* with ``None`` unless it is already done.

    Extra positional arguments (the future passed by the done-callback
    machinery) are ignored.
    """
    if not waiter.done():
        waiter.set_result(None)
class TestCase(unittest.TestCase, metaclass=TestCaseMeta):
    """Base test case that owns a per-class event loop.

    ``setUpClass`` creates the loop and stores it on ``cls.loop``.  While a
    test runs, every context passed to the loop's exception handler is
    recorded; ``tearDown`` turns the recorded contexts into a test failure.
    """

    @classmethod
    def setUpClass(cls):
        """Create the class-wide loop (uvloop policy if USE_UVLOOP is set)."""
        if os.environ.get('USE_UVLOOP'):
            import uvloop
            asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)
        cls.loop = loop

    @classmethod
    def tearDownClass(cls):
        """Close the class-wide loop."""
        cls.loop.close()
        asyncio.set_event_loop(None)

    def setUp(self):
        """Install the recording exception handler and reset its log."""
        self.loop.set_exception_handler(self.loop_exception_handler)
        self.__unhandled_exceptions = []

    def tearDown(self):
        """Fail the test if the loop reported unexpected exceptions."""
        unexpected = []

        # The recorded items are the *context dicts* handed to the loop
        # exception handler, so the exception has to be pulled out of the
        # context before its type can be inspected.
        for context in self.__unhandled_exceptions:
            exc = context.get('exception')
            if isinstance(exc, ConnectionResetError):
                texc = traceback.TracebackException.from_exception(
                    exc, lookup_lines=False)
                if (texc.stack and
                        texc.stack[-1].name == "_call_connection_lost"):
                    # On Windows calling socket.shutdown may raise
                    # ConnectionResetError, which happens in the
                    # finally block of _call_connection_lost.
                    continue
            unexpected.append(context)

        if unexpected:
            formatted = []

            for i, context in enumerate(unexpected):
                formatted.append(self._format_loop_exception(context, i + 1))

            self.fail(
                'unexpected exceptions in asynchronous code:\n' +
                '\n'.join(formatted))

    @contextlib.contextmanager
    def assertRunUnder(self, delta):
        """Assert that the managed block finishes within *delta* seconds."""
        st = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - st
            if elapsed > delta:
                raise AssertionError(
                    'running block took {:0.3f}s which is longer '
                    'than the expected maximum of {:0.3f}s'.format(
                        elapsed, delta))

    @contextlib.contextmanager
    def assertLoopErrorHandlerCalled(self, msg_re: str):
        """Assert that the block triggers a loop error matching *msg_re*.

        The loop's exception handler is temporarily replaced with one that
        collects contexts; the previous handler is restored on exit.
        """
        contexts = []

        def handler(loop, ctx):
            contexts.append(ctx)

        old_handler = self.loop.get_exception_handler()
        self.loop.set_exception_handler(handler)
        try:
            yield

            for ctx in contexts:
                msg = ctx.get('message')
                if msg and re.search(msg_re, msg):
                    return

            raise AssertionError(
                'no message matching {!r} was logged with '
                'loop.call_exception_handler()'.format(msg_re))

        finally:
            self.loop.set_exception_handler(old_handler)

    def loop_exception_handler(self, loop, context):
        """Record *context* for tearDown, then run the default handler."""
        self.__unhandled_exceptions.append(context)
        loop.default_exception_handler(context)

    def _format_loop_exception(self, context, n):
        """Render loop exception *context* number *n* as readable text."""
        message = context.get('message', 'Unhandled exception in event loop')
        exception = context.get('exception')
        if exception is not None:
            exc_info = (type(exception), exception, exception.__traceback__)
        else:
            exc_info = None

        lines = []
        for key in sorted(context):
            if key in {'message', 'exception'}:
                continue
            value = context[key]
            if key == 'source_traceback':
                tb = ''.join(traceback.format_list(value))
                value = 'Object created at (most recent call last):\n'
                value += tb.rstrip()
            else:
                try:
                    value = repr(value)
                except Exception as ex:
                    value = ('Exception in __repr__ {!r}; '
                             'value type: {!r}'.format(ex, type(value)))
            lines.append('[{}]: {}\n\n'.format(key, value))

        if exc_info is not None:
            lines.append('[exception]:\n')
            formatted_exc = textwrap.indent(
                ''.join(traceback.format_exception(*exc_info)), ' ')
            lines.append(formatted_exc)

        details = textwrap.indent(''.join(lines), ' ')
        return '{:02d}. {}:\n{}\n'.format(n, message, details)
class ClusterTestCase(TestCase):
    """Test case bound to a PostgreSQL cluster stored on ``cls.cluster``.

    Extra clusters created through ``new_cluster`` are tracked in
    ``cls._clusters`` and torn down in ``tearDownClass``; pools created
    through ``create_pool`` are tracked in ``self._pools`` and terminated
    in ``tearDown``.
    """

    @classmethod
    def get_server_settings(cls):
        """Return the server settings used to start the default cluster."""
        settings = {
            'log_connections': 'on'
        }

        if cls.cluster.get_pg_version() >= (11, 0):
            # JITting messes up timing tests, and
            # is not essential for testing.
            settings['jit'] = 'off'

        return settings

    @classmethod
    def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}):
        """Create, initialize and register a cluster of type *ClusterCls*."""
        cluster = _init_cluster(ClusterCls, cluster_kwargs,
                                _get_initdb_options(initdb_options))
        cls._clusters.append(cluster)
        return cluster

    @classmethod
    def start_cluster(cls, cluster, *, server_settings={}):
        """Start *cluster* on a dynamically chosen port."""
        # Hand over a private copy: Cluster.start() may add entries (such
        # as the socket directory) to the mapping it receives, and that
        # must not leak into the caller's dict or into the shared default.
        cluster.start(port='dynamic', server_settings=dict(server_settings))

    @classmethod
    def setup_cluster(cls):
        """Bind ``cls.cluster`` to the default cluster, starting it if needed."""
        cls.cluster = _init_default_cluster()

        if cls.cluster.get_status() != 'running':
            cls.cluster.start(
                port='dynamic', server_settings=cls.get_server_settings())

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._clusters = []
        cls.setup_cluster()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        for cluster in cls._clusters:
            if cluster is not _default_cluster:
                cluster.stop()
                cluster.destroy()
        cls._clusters = []

    @classmethod
    def get_connection_spec(cls, kwargs={}):
        """Return connection arguments for ``cls.cluster`` plus *kwargs*.

        When a ``dsn`` is supplied the cluster's ``host`` is dropped.
        Unless ``PGHOST`` is set or a ``dsn`` is given, ``database`` and
        ``user`` default to ``'postgres'``.
        """
        conn_spec = cls.cluster.get_connection_spec()
        if kwargs.get('dsn'):
            conn_spec.pop('host')
        conn_spec.update(kwargs)
        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
            if 'database' not in conn_spec:
                conn_spec['database'] = 'postgres'
            if 'user' not in conn_spec:
                conn_spec['user'] = 'postgres'
        return conn_spec

    @classmethod
    def connect(cls, **kwargs):
        """Return a connect() awaitable for the cluster on ``cls.loop``."""
        conn_spec = cls.get_connection_spec(kwargs)
        return pg_connection.connect(**conn_spec, loop=cls.loop)

    def setUp(self):
        super().setUp()
        self._pools = []

    def tearDown(self):
        # The base tearDown() fails the test when the loop reported
        # unexpected exceptions; terminate the pools regardless so that a
        # failing test does not leave its connections open.
        try:
            super().tearDown()
        finally:
            for pool in self._pools:
                pool.terminate()
            self._pools = []

    def create_pool(self, pool_class=pg_pool.Pool,
                    connection_class=pg_connection.Connection, **kwargs):
        """Create a pool against the cluster and register it for cleanup."""
        conn_spec = self.get_connection_spec(kwargs)
        pool = create_pool(loop=self.loop, pool_class=pool_class,
                           connection_class=connection_class, **conn_spec)
        self._pools.append(pool)
        return pool
class HotStandbyTestCase(ClusterTestCase):
    """Test case running against a primary cluster and a hot standby."""

    @classmethod
    def setup_cluster(cls):
        """Start the primary, create the replication role, start the standby."""
        primary = cls.new_cluster(pg_cluster.TempCluster)
        cls.master_cluster = primary
        primary_settings = {
            'max_wal_senders': 10,
            'wal_level': 'hot_standby',
        }
        cls.start_cluster(primary, server_settings=primary_settings)

        admin = None
        try:
            admin = cls.loop.run_until_complete(
                primary.connect(
                    database='postgres', user='postgres', loop=cls.loop))

            create_role = admin.execute(''' CREATE ROLE replication WITH LOGIN REPLICATION ''')
            cls.loop.run_until_complete(create_role)

            primary.trust_local_replication_by('replication')

            standby_kwargs = {
                'master': primary.get_connection_spec(),
                'replication_user': 'replication',
            }
            cls.standby_cluster = cls.new_cluster(
                pg_cluster.HotStandbyCluster, cluster_kwargs=standby_kwargs)
            cls.start_cluster(
                cls.standby_cluster, server_settings={'hot_standby': True})
        finally:
            if admin is not None:
                cls.loop.run_until_complete(admin.close())

    @classmethod
    def get_cluster_connection_spec(cls, cluster, kwargs={}):
        """Return connection arguments for *cluster* merged with *kwargs*."""
        spec = cluster.get_connection_spec()
        dsn_given = kwargs.get('dsn')
        if dsn_given:
            spec.pop('host')
        spec.update(kwargs)

        # Fall back to the stable superuser/database names only when
        # neither PGHOST nor an explicit DSN is in play.
        if not (os.environ.get('PGHOST') or kwargs.get('dsn')):
            spec.setdefault('database', 'postgres')
            spec.setdefault('user', 'postgres')
        return spec

    @classmethod
    def get_connection_spec(cls, kwargs={}):
        """Return a multi-host spec listing the primary, then the standby."""
        primary = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
        standby = cls.get_cluster_connection_spec(cls.standby_cluster, kwargs)

        spec = {
            'host': [primary['host'], standby['host']],
            'port': [primary['port'], standby['port']],
            'database': primary['database'],
            'user': primary['user'],
        }
        spec.update(kwargs)
        return spec

    @classmethod
    def connect_primary(cls, **kwargs):
        """Return a connect() awaitable targeting the primary cluster."""
        spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
        return pg_connection.connect(**spec, loop=cls.loop)

    @classmethod
    def connect_standby(cls, **kwargs):
        """Return a connect() awaitable targeting the standby cluster."""
        spec = cls.get_cluster_connection_spec(cls.standby_cluster, kwargs)
        return pg_connection.connect(**spec, loop=cls.loop)
async def listen(self):
    """Accept clients and pair each one with a new backend connection.

    Runs until ``self._wait`` raises ``StopServer``.  Every successfully
    paired connection is wrapped in a ``Connection`` whose ``handle()``
    task is stored in ``self.connections``.
    """
    while True:
        client_sock = backend_sock = None
        paired = False
        try:
            client_sock, _ = await self._wait(
                self.loop.sock_accept(self.sock))
            backend_sock = socket.socket()
            backend_sock.setblocking(False)
            await self._wait(self.loop.sock_connect(
                backend_sock, (self.backend_host, self.backend_port)))
            paired = True
        except StopServer:
            break
        finally:
            # If pairing was interrupted (stop request or connect error)
            # nobody else owns the sockets created so far, so close them
            # here instead of leaking them.
            if not paired:
                for sock in (backend_sock, client_sock):
                    if sock is not None:
                        sock.close()

        conn = Connection(client_sock, backend_sock, self)
        conn_task = self.loop.create_task(conn.handle())
        self.connections[conn] = conn_task
class Connection:
    """One proxied client <-> backend socket pair of a TCPFuzzingProxy.

    Data is relayed in both directions by two tasks started in
    ``handle()``.  Relaying pauses while the proxy's ``connectivity`` event
    is cleared and in-flight socket operations are abandoned when
    ``connectivity_loss`` is set.
    """

    def __init__(self, client_sock, backend_sock, proxy):
        self.client_sock = client_sock
        self.backend_sock = backend_sock
        self.proxy = proxy
        self.loop = proxy.loop
        self.connectivity = proxy.connectivity
        self.connectivity_loss = proxy.connectivity_loss
        self.proxy_to_backend_task = None
        self.proxy_from_backend_task = None
        self.is_closed = False

    def close(self):
        """Cancel both relay tasks and unregister from the proxy (once)."""
        if self.is_closed:
            return

        self.is_closed = True

        if self.proxy_to_backend_task is not None:
            self.proxy_to_backend_task.cancel()
            self.proxy_to_backend_task = None

        if self.proxy_from_backend_task is not None:
            self.proxy_from_backend_task.cancel()
            self.proxy_from_backend_task = None

        self.proxy._close_connection(self)

    async def handle(self):
        """Relay traffic until either direction finishes, then clean up."""
        self.proxy_to_backend_task = asyncio.ensure_future(
            self.proxy_to_backend())

        self.proxy_from_backend_task = asyncio.ensure_future(
            self.proxy_from_backend())

        try:
            await asyncio.wait(
                [self.proxy_to_backend_task, self.proxy_from_backend_task],
                return_when=asyncio.FIRST_COMPLETED)

        finally:
            if self.proxy_to_backend_task is not None:
                self.proxy_to_backend_task.cancel()

            if self.proxy_from_backend_task is not None:
                self.proxy_from_backend_task.cancel()

            try:
                # Asyncio fails to properly remove the readers and writers
                # when the task doing recv() or send() is cancelled, so
                # we must remove the readers and writers manually before
                # closing the sockets.
                for sock in (self.client_sock, self.backend_sock):
                    fd = sock.fileno()
                    self.loop.remove_reader(fd)
                    self.loop.remove_writer(fd)
            finally:
                # Close the sockets even if the loop refused the removal.
                self.client_sock.close()
                self.backend_sock.close()

    async def _until_connectivity_loss(self, operation):
        """Run *operation* unless connectivity is lost first.

        Returns the operation's result, or ``None`` if the
        ``connectivity_loss`` event is set once the wait is over.
        Whichever task is still pending afterwards is cancelled.
        """
        op_task = asyncio.ensure_future(operation)
        loss_task = asyncio.ensure_future(self.connectivity_loss.wait())

        try:
            await asyncio.wait(
                [op_task, loss_task],
                return_when=asyncio.FIRST_COMPLETED)

            if self.connectivity_loss.is_set():
                return None
            else:
                return op_task.result()
        finally:
            if not self.loop.is_closed():
                if not op_task.done():
                    op_task.cancel()
                if not loss_task.done():
                    loss_task.cancel()

    async def _read(self, sock, n):
        """Receive up to *n* bytes from *sock*; None on connectivity loss."""
        return await self._until_connectivity_loss(
            self.loop.sock_recv(sock, n))

    async def _write(self, sock, data):
        """Send all of *data* to *sock*; abandoned on connectivity loss."""
        return await self._until_connectivity_loss(
            self.loop.sock_sendall(sock, data))

    async def _relay(self, source, sink):
        """Copy data from *source* to *sink* until EOF or connection error.

        A chunk that was read but could not be forwarded because
        connectivity was lost is kept and forwarded first once the
        ``connectivity`` event is set again.
        """
        held = None

        try:
            while True:
                await self.connectivity.wait()
                if held is not None:
                    data, held = held, None
                else:
                    data = await self._read(source, 4096)
                if data == b'':
                    break
                if self.connectivity_loss.is_set():
                    if data:
                        held = data
                    continue
                await self._write(sink, data)

        except ConnectionError:
            pass

        finally:
            if not self.loop.is_closed():
                self.loop.call_soon(self.close)

    async def proxy_to_backend(self):
        """Relay client -> backend traffic."""
        await self._relay(self.client_sock, self.backend_sock)

    async def proxy_from_backend(self):
        """Relay backend -> client traffic."""
        await self._relay(self.backend_sock, self.client_sock)
def get_status(self):
    """Return the cluster status as reported by ``pg_ctl status``.

    Returns ``'not-initialized'`` when pg_ctl exits with status 4 or the
    data directory is missing or empty, ``'stopped'`` for exit status 3,
    and for exit status 0 whatever ``self._test_connection(timeout=0)``
    returns (after recording the server PID in ``self._daemon_pid``).

    Raises ClusterError if the PID cannot be parsed from the pg_ctl
    output or pg_ctl exits with any other status.
    """
    if self._pg_ctl is None:
        # Binaries are located lazily, on first use.
        self._init_env()

    process = subprocess.run(
        [self._pg_ctl, 'status', '-D', self._data_dir],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.stdout, process.stderr

    if (process.returncode == 4 or not os.path.exists(self._data_dir) or
            not os.listdir(self._data_dir)):
        return 'not-initialized'
    elif process.returncode == 3:
        return 'stopped'
    elif process.returncode == 0:
        r = re.match(r'.*PID\s?:\s+(\d+).*', stdout.decode())
        if not r:
            raise ClusterError(
                'could not parse pg_ctl status output: {}'.format(
                    stdout.decode()))
        self._daemon_pid = int(r.group(1))
        return self._test_connection(timeout=0)
    else:
        # Decode stderr so the message shows the text rather than a
        # bytes repr, like the other pg_ctl error messages do.
        raise ClusterError(
            'pg_ctl status exited with status {:d}: {}'.format(
                process.returncode, stderr.decode()))
def start(self, wait=60, *, server_settings={}, **opts):
    """Start the cluster.

    *server_settings* are passed to the server as ``-c name=value``
    options and *opts* as ``--name=value`` options.  ``port='dynamic'``
    picks a free port via ``find_available_port()``.  Does nothing if the
    cluster is already running.

    Raises ClusterError if the cluster has not been initialized or, on
    Windows, if ``pg_ctl start`` fails.
    """
    status = self.get_status()
    if status == 'running':
        return
    elif status == 'not-initialized':
        raise ClusterError(
            'cluster in {!r} has not been initialized'.format(
                self._data_dir))

    # Work on a private copy: entries are added below, and neither the
    # caller's mapping nor the shared default dict may be modified.
    server_settings = dict(server_settings)

    port = opts.pop('port', None)
    if port == 'dynamic':
        port = find_available_port()

    extra_args = ['--{}={}'.format(k, v) for k, v in opts.items()]
    extra_args.append('--port={}'.format(port))

    sockdir = server_settings.get('unix_socket_directories')
    if sockdir is None:
        sockdir = server_settings.get('unix_socket_directory')
    if sockdir is None and _system != 'Windows':
        sockdir = tempfile.gettempdir()

    ssl_key = server_settings.get('ssl_key_file')
    if ssl_key:
        # Make sure server certificate key file has correct permissions.
        keyfile = os.path.join(self._data_dir, 'srvkey.pem')
        shutil.copy(ssl_key, keyfile)
        os.chmod(keyfile, 0o600)
        server_settings['ssl_key_file'] = keyfile

    if sockdir is not None:
        if self._pg_version < (9, 3):
            sockdir_opt = 'unix_socket_directory'
        else:
            sockdir_opt = 'unix_socket_directories'

        server_settings[sockdir_opt] = sockdir

    for k, v in server_settings.items():
        extra_args.extend(['-c', '{}={}'.format(k, v)])

    if _system == 'Windows':
        # On Windows we have to use pg_ctl as direct execution
        # of postgres daemon under an Administrative account
        # is not permitted and there is no easy way to drop
        # privileges.
        if os.getenv('ASYNCPG_DEBUG_SERVER'):
            stdout = sys.stdout
            print(
                'asyncpg.cluster: Running',
                ' '.join([
                    self._pg_ctl, 'start', '-D', self._data_dir,
                    '-o', ' '.join(extra_args)
                ]),
                file=sys.stderr,
            )
        else:
            stdout = subprocess.DEVNULL

        process = subprocess.run(
            [self._pg_ctl, 'start', '-D', self._data_dir,
             '-o', ' '.join(extra_args)],
            stdout=stdout,
            stderr=subprocess.STDOUT,
            cwd=self._data_dir,
        )

        if process.returncode != 0:
            # NOTE(review): stderr is merged into stdout above and is not
            # captured, so process.stderr is expected to be None here and
            # the suffix empty.
            if process.stderr:
                stderr = ':\n{}'.format(process.stderr.decode())
            else:
                stderr = ''
            raise ClusterError(
                'pg_ctl start exited with status {:d}{}'.format(
                    process.returncode, stderr))
    else:
        if os.getenv('ASYNCPG_DEBUG_SERVER'):
            stdout = sys.stdout
        else:
            stdout = subprocess.DEVNULL

        self._daemon_process = \
            subprocess.Popen(
                [self._postgres, '-D', self._data_dir, *extra_args],
                stdout=stdout,
                stderr=subprocess.STDOUT,
                cwd=self._data_dir,
            )

        self._daemon_pid = self._daemon_process.pid

    self._test_connection(timeout=wait)


def reload(self):
    """Reload server configuration.

    Raises ClusterError if the cluster is not running or if
    ``pg_ctl reload`` exits with a non-zero status.
    """
    status = self.get_status()
    if status != 'running':
        raise ClusterError('cannot reload: cluster is not running')

    process = subprocess.run(
        [self._pg_ctl, 'reload', '-D', self._data_dir],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=self._data_dir,
    )

    stderr = process.stderr

    if process.returncode != 0:
        # Name the sub-command that actually failed.
        raise ClusterError(
            'pg_ctl reload exited with status {:d}: {}'.format(
                process.returncode, stderr.decode()))
def add_hba_entry(self, *, type='host', database, user, address=None,
                  auth_method, auth_options=None):
    """Add a record to pg_hba.conf.

    The record is appended as a single line of the form
    ``type database user [address] auth_method [name=value ...]``.

    Args:
        type: one of ``'local'``, ``'host'``, ``'hostssl'``, ``'hostnossl'``.
        database: value for the database field.
        user: value for the user field.
        address: client address; required unless *type* is ``'local'``.
        auth_method: value for the authentication method field.
        auth_options: optional authentication options, either a mapping
            or an iterable of ``(name, value)`` pairs.

    Raises:
        ValueError: on an unknown *type* or a missing *address*.
        ClusterError: if the cluster is not initialized or the file
            cannot be written.
    """
    status = self.get_status()
    if status == 'not-initialized':
        raise ClusterError(
            'cannot modify HBA records: cluster is not initialized')

    if type not in {'local', 'host', 'hostssl', 'hostnossl'}:
        raise ValueError('invalid HBA record type: {!r}'.format(type))

    pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')

    record = '{} {} {}'.format(type, database, user)

    if type != 'local':
        if address is None:
            raise ValueError(
                '{!r} entry requires a valid address'.format(type))
        record += ' {}'.format(address)

    record += ' {}'.format(auth_method)

    if auth_options is not None:
        # Iterating a dict directly yields only its keys, which cannot be
        # unpacked into (name, value); use .items() for mappings and keep
        # accepting a plain iterable of pairs.
        if hasattr(auth_options, 'items'):
            option_pairs = auth_options.items()
        else:
            option_pairs = auth_options
        record += ' ' + ' '.join(
            '{}={}'.format(k, v) for k, v in option_pairs)

    try:
        with open(pg_hba, 'a') as f:
            print(record, file=f)
    except IOError as e:
        raise ClusterError(
            'cannot modify HBA records: {}'.format(e)) from e
def _connection_addr_from_pidfile(self):
    """Work out the connection address from ``postmaster.pid``.

    Returns a dict with ``'host'`` and ``'port'`` keys (the port is the
    string read from the file), or ``None`` when the pidfile is missing,
    has fewer than six lines, or records a PID different from
    ``self._daemon_pid``.
    """
    pidfile_path = os.path.join(self._data_dir, 'postmaster.pid')

    try:
        with open(pidfile_path, 'rt') as pidfile:
            contents = pidfile.read()
    except FileNotFoundError:
        return None

    entries = contents.splitlines()

    # A complete postgres pidfile is at least 6 lines.
    if len(entries) < 6:
        return None

    recorded_pid = int(entries[0])
    if self._daemon_pid and recorded_pid != self._daemon_pid:
        # Possibly a stale pidfile left behind by an earlier daemon run.
        return None

    port, socket_dir, listen_addr = entries[3:6]

    if not socket_dir:
        host = listen_addr
    elif socket_dir.startswith('/'):
        host = socket_dir
    else:
        # A relative socket directory is relative to the data directory.
        host = os.path.normpath(os.path.join(self._data_dir, socket_dir))

    # Wildcard listen addresses are replaced by the matching loopback.
    wildcard_to_loopback = {
        '*': 'localhost',
        '0.0.0.0': '127.0.0.1',
        '::': '::1',
    }
    host = wildcard_to_loopback.get(host, host)

    return {'host': host, 'port': port}
def _find_pg_config(self, pg_config_path):
    """Return the path of the ``pg_config`` executable to use.

    When *pg_config_path* is ``None`` the executable is looked up in the
    directory named by ``PGINSTALLATION`` or ``PGBIN``; if neither is
    set, each directory listed in ``PATH`` is tried in order.

    Raises:
        ClusterError: if no candidate is found, or if the candidate is
            not a regular file.
    """
    if pg_config_path is None:
        pg_install = (
            os.environ.get('PGINSTALLATION')
            or os.environ.get('PGBIN')
        )
        if pg_install:
            pg_config_path = platform_exe(
                os.path.join(pg_install, 'pg_config'))
        else:
            # PATH may be unset; treat that as "nothing to search" so the
            # caller gets a ClusterError and not an AttributeError from
            # calling .split() on None.
            pathenv = os.environ.get('PATH', '')
            search_dirs = pathenv.split(os.pathsep) if pathenv else []
            for path in search_dirs:
                pg_config_path = platform_exe(
                    os.path.join(path, 'pg_config'))
                if os.path.exists(pg_config_path):
                    break
            else:
                pg_config_path = None

    if not pg_config_path:
        raise ClusterError('could not find pg_config executable')

    if not os.path.isfile(pg_config_path):
        raise ClusterError('{!r} is not an executable'.format(
            pg_config_path))

    return pg_config_path
class HotStandbyCluster(TempCluster):
    """Temporary cluster populated from another server with pg_basebackup.

    *master* is indexed with ``'host'`` and ``'port'`` to locate the
    server to copy from; *replication_user* is the user name passed to
    ``pg_basebackup`` and written into ``primary_conninfo``.
    """

    def __init__(self, *,
                 master, replication_user,
                 data_dir_suffix=None, data_dir_prefix=None,
                 data_dir_parent=None, pg_config_path=None):
        # These are set before the parent initializer runs so they are
        # already available to anything it calls.
        self._master = master
        self._repl_user = replication_user
        super().__init__(
            data_dir_suffix=data_dir_suffix,
            data_dir_prefix=data_dir_prefix,
            data_dir_parent=data_dir_parent,
            pg_config_path=pg_config_path)

    def _init_env(self):
        super()._init_env()
        self._pg_basebackup = self._find_pg_binary('pg_basebackup')

    def init(self, **settings):
        """Initialize cluster.

        Copies the master's data directory with ``pg_basebackup`` and
        marks the copy as a standby. *settings* is accepted for
        signature compatibility and is not used here.

        Returns:
            The decoded output of ``pg_basebackup``.

        Raises:
            ClusterError: if the cluster is already initialized or
                ``pg_basebackup`` exits with a non-zero status.
        """
        if self.get_status() != 'not-initialized':
            raise ClusterError(
                'cluster in {!r} has already been initialized'.format(
                    self._data_dir))

        process = subprocess.run(
            [self._pg_basebackup, '-h', self._master['host'],
             # argv entries must be strings; the port may be an int.
             '-p', str(self._master['port']), '-D', self._data_dir,
             '-U', self._repl_user],
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

        output = process.stdout

        if process.returncode != 0:
            raise ClusterError(
                'pg_basebackup init exited with status {:d}:\n{}'.format(
                    process.returncode, output.decode()))

        if self._pg_version < (12, 0):
            # Before version 12 standby mode is configured through
            # recovery.conf.
            with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
                f.write(textwrap.dedent("""\
                    standby_mode = 'on'
                    primary_conninfo = 'host={host} port={port} user={user}'
                """.format(
                    host=self._master['host'],
                    port=self._master['port'],
                    user=self._repl_user)))
        else:
            # From version 12 on an empty standby.signal file is created
            # and primary_conninfo is supplied by start().
            with open(os.path.join(self._data_dir, 'standby.signal'), 'w'):
                pass

        return output.decode()

    def start(self, wait=60, *, server_settings=None, **opts):
        """Start the standby, adding ``primary_conninfo`` on version 12+.

        *server_settings* is never mutated; a copy is extended and passed
        to the parent ``start``.
        """
        if server_settings is None:
            server_settings = {}

        if self._pg_version >= (12, 0):
            server_settings = server_settings.copy()
            server_settings['primary_conninfo'] = (
                '"host={host} port={port} user={user}"'.format(
                    host=self._master['host'],
                    port=self._master['port'],
                    user=self._repl_user,
                )
            )

        super().start(wait=wait, server_settings=server_settings, **opts)
async def wait_closed(stream: asyncio.StreamWriter) -> None:
    """Await ``stream.wait_closed()`` when the stream provides it.

    A ``ConnectionResetError`` raised while waiting is ignored; any
    other exception propagates.
    """
    if not hasattr(stream, 'wait_closed'):
        # Not all asyncio versions have StreamWriter.wait_closed().
        return

    try:
        await stream.wait_closed()
    except ConnectionResetError:
        # On Windows wait_closed() sometimes propagates
        # ConnectionResetError, which is of no interest to the caller.
        return
class SSLMode(enum.IntEnum):
    """SSL modes accepted for the ``sslmode`` setting.

    Members are integers, so modes can be compared with ``<`` / ``>=``.
    """

    disable = 0
    allow = 1
    prefer = 2
    require = 3
    verify_ca = 4
    verify_full = 5

    @classmethod
    def parse(cls, sslmode):
        """Return the member named by *sslmode*.

        *sslmode* may already be a member, or a string in which ``-`` is
        treated as ``_`` (``'verify-full'`` -> ``verify_full``).

        Raises:
            AttributeError: if *sslmode* does not name a member.
        """
        if isinstance(sslmode, cls):
            return sslmode

        name = sslmode.replace('-', '_')
        # Look the name up among the members only. getattr() on the class
        # would also resolve methods and dunder attributes (for example
        # 'parse' or '__class__') and hand back a non-member.
        try:
            return cls.__members__[name]
        except KeyError:
            # Callers detect an invalid mode by catching AttributeError.
            raise AttributeError(name) from None
def _parse_tls_version(tls_version):
    """Convert a TLS version name such as ``'TLSv1.2'`` to ``TLSVersion``.

    Dots in the name are replaced with underscores before the lookup in
    ``ssl.TLSVersion``.

    Raises:
        ClientConfigurationError: if the name starts with ``'SSL'`` or
            is not a ``TLSVersion`` member name.
    """
    if tls_version.startswith('SSL'):
        raise exceptions.ClientConfigurationError(
            f"Unsupported TLS version: {tls_version}"
        )
    try:
        return ssl_module.TLSVersion[tls_version.replace('.', '_')]
    except KeyError:
        # The KeyError is an implementation detail of the lookup; suppress
        # it so the traceback shows only the configuration error, as the
        # other configuration checks in this module do.
        raise exceptions.ClientConfigurationError(
            f"No such TLS version: {tls_version}"
        ) from None
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
                                password, passfile, database, ssl,
                                service, servicefile,
                                direct_tls, server_settings,
                                target_session_attrs, krbsrvname, gsslib):
    """Resolve connection settings into addresses and parameters.

    Explicit arguments take precedence. Values still unset are taken
    from the DSN (netloc, path, then query string), then from the
    selected section of the service file (only consulted when a DSN is
    given), then from the ``PG*`` environment variables, and finally
    from built-in defaults. DSN query keys that are not recognised are
    merged into ``server_settings``.

    Returns:
        A ``(addrs, params)`` tuple. ``addrs`` is a list whose items are
        either a UNIX socket path (``str``) or a ``(host, port)`` tuple;
        ``params`` is a ``_ConnectionParameters`` instance.

    Raises:
        ClientConfigurationError: on an invalid DSN scheme, sslmode,
            sslnegotiation, server_settings, target_session_attrs or
            gsslib value, on an unusable root certificate, or when the
            user or database name cannot be determined.
        InternalClientError: if no address could be determined.
    """
    # `auth_hosts` is the version of host information for the purposes
    # of reading the pgpass file.
    auth_hosts = None
    sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
    ssl_min_protocol_version = ssl_max_protocol_version = None
    sslnegotiation = None

    if dsn:
        parsed = urllib.parse.urlparse(dsn)

        query = None
        if parsed.query:
            query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
            # parse_qs() returns a list per key; the last occurrence wins.
            for key, val in query.items():
                if isinstance(val, list):
                    query[key] = val[-1]

            if 'service' in query:
                val = query.pop('service')
                if not service and val:
                    service = val

        connection_service_file = servicefile

        if connection_service_file is None:
            connection_service_file = os.getenv('PGSERVICEFILE')

        if connection_service_file is None:
            homedir = compat.get_pg_home_directory()
            if homedir:
                connection_service_file = homedir / PG_SERVICEFILE
            else:
                connection_service_file = None
        else:
            connection_service_file = pathlib.Path(connection_service_file)

        if parsed.scheme not in {'postgresql', 'postgres'}:
            raise exceptions.ClientConfigurationError(
                'invalid DSN: scheme is expected to be either '
                '"postgresql" or "postgres", got {!r}'.format(parsed.scheme))

        if parsed.netloc:
            if '@' in parsed.netloc:
                dsn_auth, _, dsn_hostspec = parsed.netloc.partition('@')
            else:
                dsn_hostspec = parsed.netloc
                dsn_auth = ''
        else:
            dsn_auth = dsn_hostspec = ''

        if dsn_auth:
            dsn_user, _, dsn_password = dsn_auth.partition(':')
        else:
            dsn_user = dsn_password = ''

        if not host and dsn_hostspec:
            host, port = _parse_hostlist(dsn_hostspec, port, unquote=True)

        if parsed.path and database is None:
            dsn_database = parsed.path
            if dsn_database.startswith('/'):
                dsn_database = dsn_database[1:]
            database = urllib.parse.unquote(dsn_database)

        if user is None and dsn_user:
            user = urllib.parse.unquote(dsn_user)

        if password is None and dsn_password:
            password = urllib.parse.unquote(dsn_password)

        if query:
            if 'port' in query:
                val = query.pop('port')
                if not port and val:
                    port = [int(p) for p in val.split(',')]

            if 'host' in query:
                val = query.pop('host')
                if not host and val:
                    host, port = _parse_hostlist(val, port)

            if 'dbname' in query:
                val = query.pop('dbname')
                if database is None:
                    database = val

            if 'database' in query:
                val = query.pop('database')
                if database is None:
                    database = val

            if 'user' in query:
                val = query.pop('user')
                if user is None:
                    user = val

            if 'password' in query:
                val = query.pop('password')
                if password is None:
                    password = val

            if 'passfile' in query:
                val = query.pop('passfile')
                if passfile is None:
                    passfile = val

            if 'sslmode' in query:
                val = query.pop('sslmode')
                if ssl is None:
                    ssl = val

            if 'sslcert' in query:
                sslcert = query.pop('sslcert')

            if 'sslkey' in query:
                sslkey = query.pop('sslkey')

            if 'sslrootcert' in query:
                sslrootcert = query.pop('sslrootcert')

            if 'sslnegotiation' in query:
                sslnegotiation = query.pop('sslnegotiation')

            if 'sslcrl' in query:
                sslcrl = query.pop('sslcrl')

            if 'sslpassword' in query:
                sslpassword = query.pop('sslpassword')

            if 'ssl_min_protocol_version' in query:
                ssl_min_protocol_version = query.pop(
                    'ssl_min_protocol_version'
                )

            if 'ssl_max_protocol_version' in query:
                ssl_max_protocol_version = query.pop(
                    'ssl_max_protocol_version'
                )

            if 'target_session_attrs' in query:
                dsn_target_session_attrs = query.pop(
                    'target_session_attrs'
                )
                if target_session_attrs is None:
                    target_session_attrs = dsn_target_session_attrs

            if 'krbsrvname' in query:
                val = query.pop('krbsrvname')
                if krbsrvname is None:
                    krbsrvname = val

            if 'gsslib' in query:
                val = query.pop('gsslib')
                if gsslib is None:
                    gsslib = val

            if 'service' in query:
                val = query.pop('service')
                if service is None:
                    service = val

            # Whatever is left in the query string is passed to the
            # server as settings; explicit server_settings win.
            if query:
                if server_settings is None:
                    server_settings = query
                else:
                    server_settings = {**query, **server_settings}

        if connection_service_file is not None and service is not None:
            pg_service = configparser.ConfigParser()
            pg_service.read(connection_service_file)
            if service in pg_service.sections():
                service_params = pg_service[service]
                if 'port' in service_params:
                    val = service_params.pop('port')
                    if not port and val:
                        port = [int(p) for p in val.split(',')]

                if 'host' in service_params:
                    val = service_params.pop('host')
                    if not host and val:
                        host, port = _parse_hostlist(val, port)

                if 'dbname' in service_params:
                    val = service_params.pop('dbname')
                    if database is None:
                        database = val

                if 'database' in service_params:
                    val = service_params.pop('database')
                    if database is None:
                        database = val

                if 'user' in service_params:
                    val = service_params.pop('user')
                    if user is None:
                        user = val

                if 'password' in service_params:
                    val = service_params.pop('password')
                    if password is None:
                        password = val

                if 'passfile' in service_params:
                    val = service_params.pop('passfile')
                    if passfile is None:
                        passfile = val

                if 'sslmode' in service_params:
                    val = service_params.pop('sslmode')
                    if ssl is None:
                        ssl = val

                if 'sslcert' in service_params:
                    val = service_params.pop('sslcert')
                    if sslcert is None:
                        sslcert = val

                if 'sslkey' in service_params:
                    val = service_params.pop('sslkey')
                    if sslkey is None:
                        sslkey = val

                if 'sslrootcert' in service_params:
                    val = service_params.pop('sslrootcert')
                    if sslrootcert is None:
                        sslrootcert = val

                if 'sslnegotiation' in service_params:
                    val = service_params.pop('sslnegotiation')
                    if sslnegotiation is None:
                        sslnegotiation = val

                if 'sslcrl' in service_params:
                    val = service_params.pop('sslcrl')
                    if sslcrl is None:
                        sslcrl = val

                if 'sslpassword' in service_params:
                    val = service_params.pop('sslpassword')
                    if sslpassword is None:
                        sslpassword = val

                if 'ssl_min_protocol_version' in service_params:
                    val = service_params.pop(
                        'ssl_min_protocol_version'
                    )
                    if ssl_min_protocol_version is None:
                        ssl_min_protocol_version = val

                if 'ssl_max_protocol_version' in service_params:
                    val = service_params.pop(
                        'ssl_max_protocol_version'
                    )
                    if ssl_max_protocol_version is None:
                        ssl_max_protocol_version = val

                if 'target_session_attrs' in service_params:
                    dsn_target_session_attrs = service_params.pop(
                        'target_session_attrs'
                    )
                    if target_session_attrs is None:
                        target_session_attrs = dsn_target_session_attrs

                if 'krbsrvname' in service_params:
                    val = service_params.pop('krbsrvname')
                    if krbsrvname is None:
                        krbsrvname = val

                if 'gsslib' in service_params:
                    val = service_params.pop('gsslib')
                    if gsslib is None:
                        gsslib = val

    # NOTE(review): `service` is not read again below, so PGSERVICE set
    # here appears to have no effect on the result -- confirm whether the
    # service-file lookup above was meant to honour it.
    if not service:
        service = os.environ.get('PGSERVICE')
    if not host:
        hostspec = os.environ.get('PGHOST')
        if hostspec:
            host, port = _parse_hostlist(hostspec, port)

    if not host:
        auth_hosts = ['localhost']

        if _system == 'Windows':
            host = ['localhost']
        else:
            host = ['/run/postgresql', '/var/run/postgresql',
                    '/tmp', '/private/tmp', 'localhost']

    if not isinstance(host, (list, tuple)):
        host = [host]

    if auth_hosts is None:
        auth_hosts = host

    if not port:
        portspec = os.environ.get('PGPORT')
        if portspec:
            if ',' in portspec:
                port = [int(p) for p in portspec.split(',')]
            else:
                port = int(portspec)
        else:
            port = 5432

    elif isinstance(port, (list, tuple)):
        port = [int(p) for p in port]

    else:
        port = int(port)

    port = _validate_port_spec(host, port)

    if user is None:
        user = os.getenv('PGUSER')
        if not user:
            user = getpass.getuser()

    if password is None:
        password = os.getenv('PGPASSWORD')

    if database is None:
        database = os.getenv('PGDATABASE')

    if database is None:
        database = user

    if user is None:
        raise exceptions.ClientConfigurationError(
            'could not determine user name to connect with')

    if database is None:
        raise exceptions.ClientConfigurationError(
            'could not determine database name to connect to')

    if password is None:
        if passfile is None:
            passfile = os.getenv('PGPASSFILE')

        if passfile is None:
            homedir = compat.get_pg_home_directory()
            if homedir:
                passfile = homedir / PGPASSFILE
            else:
                passfile = None
        else:
            passfile = pathlib.Path(passfile)

        if passfile is not None:
            password = _read_password_from_pgpass(
                hosts=auth_hosts, ports=port,
                database=database, user=user,
                passfile=passfile)

    addrs = []
    have_tcp_addrs = False
    for h, p in zip(host, port):
        if h.startswith('/'):
            # UNIX socket name
            if '.s.PGSQL.' not in h:
                h = os.path.join(h, '.s.PGSQL.{}'.format(p))
            addrs.append(h)
        else:
            # TCP host/port
            addrs.append((h, p))
            have_tcp_addrs = True

    if not addrs:
        raise exceptions.InternalClientError(
            'could not determine the database address to connect to')

    if ssl is None:
        ssl = os.getenv('PGSSLMODE')

    if ssl is None and have_tcp_addrs:
        ssl = 'prefer'

    if direct_tls is not None:
        sslneg = (
            SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
        )
    else:
        if sslnegotiation is None:
            sslnegotiation = os.environ.get("PGSSLNEGOTIATION")

        if sslnegotiation is not None:
            try:
                sslneg = SSLNegotiation(sslnegotiation)
            except ValueError:
                modes = ', '.join(
                    m.name.replace('_', '-')
                    for m in SSLNegotiation
                )
                raise exceptions.ClientConfigurationError(
                    f'`sslnegotiation` parameter must be one of: {modes}'
                ) from None
        else:
            sslneg = SSLNegotiation.postgres

    if isinstance(ssl, (str, SSLMode)):
        try:
            sslmode = SSLMode.parse(ssl)
        except AttributeError:
            modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
            raise exceptions.ClientConfigurationError(
                '`sslmode` parameter must be one of: {}'.format(modes)
            ) from None

        # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
        if sslmode < SSLMode.allow:
            ssl = False
        else:
            ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT)
            ssl.check_hostname = sslmode >= SSLMode.verify_full
            if sslmode < SSLMode.require:
                ssl.verify_mode = ssl_module.CERT_NONE
            else:
                if sslrootcert is None:
                    sslrootcert = os.getenv('PGSSLROOTCERT')
                if sslrootcert:
                    ssl.load_verify_locations(cafile=sslrootcert)
                    ssl.verify_mode = ssl_module.CERT_REQUIRED
                else:
                    try:
                        sslrootcert = _dot_postgresql_path('root.crt')
                        if sslrootcert is not None:
                            ssl.load_verify_locations(cafile=sslrootcert)
                        else:
                            raise exceptions.ClientConfigurationError(
                                'cannot determine location of user '
                                'PostgreSQL configuration directory'
                            )
                    except (
                        exceptions.ClientConfigurationError,
                        FileNotFoundError,
                        NotADirectoryError,
                    ):
                        if sslmode > SSLMode.require:
                            if sslrootcert is None:
                                sslrootcert = '~/.postgresql/root.crt'
                                detail = (
                                    'Could not determine location of user '
                                    'home directory (HOME is either unset, '
                                    'inaccessible, or does not point to a '
                                    'valid directory)'
                                )
                            else:
                                detail = None
                            raise exceptions.ClientConfigurationError(
                                f'root certificate file "{sslrootcert}" does '
                                f'not exist or cannot be accessed',
                                hint='Provide the certificate file directly '
                                     f'or make sure "{sslrootcert}" '
                                     'exists and is readable.',
                                detail=detail,
                            )
                        elif sslmode == SSLMode.require:
                            ssl.verify_mode = ssl_module.CERT_NONE
                        else:
                            assert False, 'unreachable'
                    else:
                        ssl.verify_mode = ssl_module.CERT_REQUIRED

                if sslcrl is None:
                    sslcrl = os.getenv('PGSSLCRL')
                if sslcrl:
                    ssl.load_verify_locations(cafile=sslcrl)
                    ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
                else:
                    sslcrl = _dot_postgresql_path('root.crl')
                    if sslcrl is not None:
                        try:
                            ssl.load_verify_locations(cafile=sslcrl)
                        except (
                            FileNotFoundError,
                            NotADirectoryError,
                        ):
                            pass
                        else:
                            ssl.verify_flags |= \
                                ssl_module.VERIFY_CRL_CHECK_CHAIN

            if sslkey is None:
                sslkey = os.getenv('PGSSLKEY')
            if not sslkey:
                sslkey = _dot_postgresql_path('postgresql.key')
                if sslkey is not None and not sslkey.exists():
                    sslkey = None
            if not sslpassword:
                sslpassword = ''
            if sslcert is None:
                sslcert = os.getenv('PGSSLCERT')
            if sslcert:
                ssl.load_cert_chain(
                    sslcert, keyfile=sslkey, password=lambda: sslpassword
                )
            else:
                sslcert = _dot_postgresql_path('postgresql.crt')
                if sslcert is not None:
                    try:
                        ssl.load_cert_chain(
                            sslcert,
                            keyfile=sslkey,
                            password=lambda: sslpassword
                        )
                    except (FileNotFoundError, NotADirectoryError):
                        pass

            # OpenSSL 1.1.1 keylog file, copied from create_default_context()
            if hasattr(ssl, 'keylog_filename'):
                keylogfile = os.environ.get('SSLKEYLOGFILE')
                if keylogfile and not sys.flags.ignore_environment:
                    ssl.keylog_filename = keylogfile

            if ssl_min_protocol_version is None:
                ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION')
            if ssl_min_protocol_version:
                ssl.minimum_version = _parse_tls_version(
                    ssl_min_protocol_version
                )
            else:
                ssl.minimum_version = _parse_tls_version('TLSv1.2')

            if ssl_max_protocol_version is None:
                ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION')
            if ssl_max_protocol_version:
                ssl.maximum_version = _parse_tls_version(
                    ssl_max_protocol_version
                )

    elif ssl is True:
        ssl = ssl_module.create_default_context()
        sslmode = SSLMode.verify_full
    else:
        sslmode = SSLMode.disable

    if server_settings is not None and (
            not isinstance(server_settings, dict) or
            not all(isinstance(k, str) for k in server_settings) or
            not all(isinstance(v, str) for v in server_settings.values())):
        raise exceptions.ClientConfigurationError(
            'server_settings is expected to be None or '
            'a Dict[str, str]')

    if target_session_attrs is None:
        target_session_attrs = os.getenv(
            "PGTARGETSESSIONATTRS", SessionAttribute.any
        )
    try:
        target_session_attrs = SessionAttribute(target_session_attrs)
    except ValueError:
        # List the accepted values. The previous code formatted the bound
        # method `__members__.values` (never called), so the message showed
        # a method repr and not the choices.
        raise exceptions.ClientConfigurationError(
            "target_session_attrs is expected to be one of "
            "{!r}"
            ", got {!r}".format(
                [attr.value for attr in SessionAttribute],
                target_session_attrs
            )
        ) from None

    if krbsrvname is None:
        krbsrvname = os.getenv('PGKRBSRVNAME')

    if gsslib is None:
        gsslib = os.getenv('PGGSSLIB')
        if gsslib is None:
            gsslib = 'sspi' if _system == 'Windows' else 'gssapi'
    if gsslib not in {'gssapi', 'sspi'}:
        raise exceptions.ClientConfigurationError(
            "gsslib parameter must be either 'gssapi' or 'sspi'"
            ", got {!r}".format(gsslib))

    params = _ConnectionParameters(
        user=user, password=password, database=database, ssl=ssl,
        sslmode=sslmode, ssl_negotiation=sslneg,
        server_settings=server_settings,
        target_session_attrs=target_session_attrs,
        krbsrvname=krbsrvname, gsslib=gsslib)

    return addrs, params
class TLSUpgradeProto(asyncio.Protocol):
    """Protocol that waits for the server's reply to an SSLRequest.

    The outcome is published through the ``on_data`` future: ``True``
    when the server answers ``b'S'``, ``False`` when it answers ``b'N'``
    and an unencrypted connection is acceptable, otherwise an exception.
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        host: str,
        port: int,
        ssl_context: ssl_module.SSLContext,
        ssl_is_advisory: bool,
    ) -> None:
        self.on_data = _create_future(loop)
        self.host = host
        self.port = port
        self.ssl_context = ssl_context
        self.ssl_is_advisory = ssl_is_advisory

    def data_received(self, data: bytes) -> None:
        if data == b'S':
            self.on_data.set_result(True)
            return

        # ssl_is_advisory should already imply verify_mode == CERT_NONE,
        # as it only comes from sslmode=prefer. The verify_mode check is
        # kept as an extra guard so that a context asking for real
        # security never falls back to an insecure connection.
        plaintext_acceptable = (
            self.ssl_is_advisory
            and self.ssl_context.verify_mode == ssl_module.CERT_NONE
        )
        if plaintext_acceptable and data == b'N':
            self.on_data.set_result(False)
            return

        rejection = ConnectionError(
            'PostgreSQL server at "{host}:{port}" '
            'rejected SSL upgrade'.format(
                host=self.host, port=self.port))
        self.on_data.set_exception(rejection)

    def connection_lost(self, exc: typing.Optional[Exception]) -> None:
        if self.on_data.done():
            return

        if exc is None:
            exc = ConnectionError('unexpected connection_lost() call')
        self.on_data.set_exception(exc)
async def _connect_addr(
    *,
    addr,
    loop,
    params,
    config,
    connection_class,
    record_class
):
    """Connect to a single address, retrying once for allow/prefer sslmode.

    If ``params.password`` is callable it is invoked (and awaited when it
    returns an awaitable) and the result replaces the password for this
    attempt; the original, unresolved *params* are what gets passed on to
    the connection object.

    For ``SSLMode.allow`` the first attempt is made without SSL and the
    retry with it; for ``SSLMode.prefer`` the order is reversed.  Every
    other mode gets exactly one attempt.
    """
    assert loop is not None

    # Keep the caller-supplied params (password possibly still a callable)
    # for the Connection object.
    params_input = params
    if callable(params.password):
        password = params.password()
        if inspect.isawaitable(password):
            password = await password

        params = params._replace(password=password)
    args = (addr, loop, config, connection_class, record_class, params_input)

    # prepare the params (which attempt has ssl) for the 2 attempts
    if params.sslmode == SSLMode.allow:
        params_retry = params
        params = params._replace(ssl=None)
    elif params.sslmode == SSLMode.prefer:
        params_retry = params._replace(ssl=None)
    else:
        # skip retry if we don't have to
        return await __connect_addr(params, False, *args)

    # first attempt
    try:
        return await __connect_addr(params, True, *args)
    except _RetryConnectSignal:
        pass

    # second attempt
    return await __connect_addr(params_retry, False, *args)


class _RetryConnectSignal(Exception):
    # Internal control-flow signal raised by __connect_addr() to ask
    # _connect_addr() for the second attempt; not meant to reach users.
    pass


async def __connect_addr(
    params,
    retry,
    addr,
    loop,
    config,
    connection_class,
    record_class,
    params_input,
):
    """Perform one connection attempt and build the connection object.

    :param params: Connection parameters used for this attempt.
    :param bool retry: Whether a failed attempt may raise
        ``_RetryConnectSignal`` to request another attempt.
    :param addr: A string (UNIX socket path) or a sequence that is
        unpacked into ``loop.create_connection()``.
    :param params_input: The parameters handed to *connection_class*.

    :return: ``connection_class(pr, tr, loop, addr, config, params_input)``.
    """
    # Awaited below; handed to the protocol so it can signal when the
    # connection is established.
    connected = _create_future(loop)

    proto_factory = lambda: protocol.Protocol(
        addr, connected, params, record_class, loop)

    if isinstance(addr, str):
        # UNIX socket
        connector = loop.create_unix_connection(proto_factory, addr)

    elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
        # if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform
        # direct SSL connection
        connector = loop.create_connection(
            proto_factory, *addr, ssl=params.ssl
        )

    elif params.ssl:
        connector = _create_ssl_connection(
            proto_factory, *addr, loop=loop, ssl_context=params.ssl,
            ssl_is_advisory=params.sslmode == SSLMode.prefer)
    else:
        connector = loop.create_connection(proto_factory, *addr)

    tr, pr = await connector

    try:
        await connected
    except (
        exceptions.InvalidAuthorizationSpecificationError,
        exceptions.ConnectionDoesNotExistError,  # seen on Windows
    ):
        tr.close()

        # retry=True here is a redundant check because we don't want to
        # accidentally raise the internal _RetryConnectSignal to the user
        if retry and (
            params.sslmode == SSLMode.allow and not pr.is_ssl or
            params.sslmode == SSLMode.prefer and pr.is_ssl
        ):
            # Trigger retry when:
            #   1. First attempt with sslmode=allow, ssl=None failed
            #   2. First attempt with sslmode=prefer, ssl=ctx failed while the
            #      server claimed to support SSL (returning "S" for SSLRequest)
            #      (likely because pg_hba.conf rejected the connection)
            raise _RetryConnectSignal()

        else:
            # but will NOT retry if:
            #   1. First attempt with sslmode=prefer failed but the server
            #      doesn't support SSL (returning 'N' for SSLRequest), because
            #      we already tried to connect without SSL thru ssl_is_advisory
            #   2. Second attempt with sslmode=prefer, ssl=None failed
            #   3. Second attempt with sslmode=allow, ssl=ctx failed
            #   4. Any other sslmode
            raise

    except (Exception, asyncio.CancelledError):
        # Any other failure (including cancellation): release the transport.
        tr.close()
        raise

    con = connection_class(pr, tr, loop, addr, config, params_input)
    pr.set_connection(con)
    return con
Any other sslmode raise except (Exception, asyncio.CancelledError): tr.close() raise con = connection_class(pr, tr, loop, addr, config, params_input) pr.set_connection(con) return con class SessionAttribute(str, enum.Enum): any = 'any' primary = 'primary' standby = 'standby' prefer_standby = 'prefer-standby' read_write = "read-write" read_only = "read-only" def _accept_in_hot_standby(should_be_in_hot_standby: bool): """ If the server didn't report "in_hot_standby" at startup, we must determine the state by checking "SELECT pg_catalog.pg_is_in_recovery()". If the server allows a connection and states it is in recovery it must be a replica/standby server. """ async def can_be_used(connection): settings = connection.get_settings() hot_standby_status = getattr(settings, 'in_hot_standby', None) if hot_standby_status is not None: is_in_hot_standby = hot_standby_status == 'on' else: is_in_hot_standby = await connection.fetchval( "SELECT pg_catalog.pg_is_in_recovery()" ) return is_in_hot_standby == should_be_in_hot_standby return can_be_used def _accept_read_only(should_be_read_only: bool): """ Verify the server has not set default_transaction_read_only=True """ async def can_be_used(connection): settings = connection.get_settings() is_readonly = getattr(settings, 'default_transaction_read_only', 'off') if is_readonly == "on": return should_be_read_only return await _accept_in_hot_standby(should_be_read_only)(connection) return can_be_used async def _accept_any(_): return True target_attrs_check = { SessionAttribute.any: _accept_any, SessionAttribute.primary: _accept_in_hot_standby(False), SessionAttribute.standby: _accept_in_hot_standby(True), SessionAttribute.prefer_standby: _accept_in_hot_standby(True), SessionAttribute.read_write: _accept_read_only(False), SessionAttribute.read_only: _accept_read_only(True), } async def _can_use_connection(connection, attr: SessionAttribute): can_use = target_attrs_check[attr] return await can_use(connection) async def _connect(*, 
async def _cancel(*, loop, addr, params: _ConnectionParameters,
                  backend_pid, backend_secret):
    """Send a CancelRequest for a backend over a fresh connection.

    A new connection is opened to *addr* (a UNIX socket path when it is a
    string, otherwise a sequence unpacked into the TCP connect call), the
    CancelRequest message is written, and the coroutine then waits until
    the connection is lost.  The transport is always closed on exit.

    :param params: Connection parameters; ``ssl`` and ``sslmode`` decide
        whether the request is sent through ``_create_ssl_connection``.
    :param backend_pid: Packed into the message as the third field.
    :param backend_secret: Packed into the message as the fourth field.
    """

    class CancelProto(asyncio.Protocol):

        def __init__(self):
            self.on_disconnect = _create_future(loop)
            self.is_ssl = False

        def connection_lost(self, exc):
            if not self.on_disconnect.done():
                self.on_disconnect.set_result(True)

    # Only the plain TCP branch sets TCP_NODELAY here.
    needs_nodelay = False

    if isinstance(addr, str):
        tr, pr = await loop.create_unix_connection(CancelProto, addr)
    elif params.ssl and params.sslmode != SSLMode.allow:
        tr, pr = await _create_ssl_connection(
            CancelProto,
            *addr,
            loop=loop,
            ssl_context=params.ssl,
            ssl_is_advisory=params.sslmode == SSLMode.prefer)
    else:
        tr, pr = await loop.create_connection(
            CancelProto, *addr)
        needs_nodelay = True

    # From here on the transport is open: everything that can fail runs
    # under the ``finally`` so the transport is never leaked.
    try:
        if needs_nodelay:
            _set_nodelay(_get_socket(tr))

        # Pack a CancelRequest message
        msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)

        tr.write(msg)
        await pr.on_disconnect
    finally:
        tr.close()
_get_socket(transport): sock = transport.get_extra_info('socket') if sock is None: # Shouldn't happen with any asyncio-complaint event loop. raise ConnectionError( 'could not get the socket for transport {!r}'.format(transport)) return sock def _set_nodelay(sock): if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) def _create_future(loop): try: create_future = loop.create_future except AttributeError: return asyncio.Future(loop=loop) else: return create_future() ================================================ FILE: asyncpg/connection.py ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 import asyncio import asyncpg import collections import collections.abc import contextlib import functools import itertools import inspect import os import sys import time import traceback import typing import warnings import weakref from . import compat from . import connect_utils from . import cursor from . import exceptions from . import introspection from . import prepared_stmt from . import protocol from . import serverversion from . import transaction from . import utils class ConnectionMeta(type): def __instancecheck__(cls, instance): mro = type(instance).__mro__ return Connection in mro or _ConnectionProxy in mro class Connection(metaclass=ConnectionMeta): """A representation of a database session. Connections are created by calling :func:`~asyncpg.connection.connect`. 
""" __slots__ = ('_protocol', '_transport', '_loop', '_top_xact', '_aborted', '_pool_release_ctr', '_stmt_cache', '_stmts_to_close', '_stmt_cache_enabled', '_listeners', '_server_version', '_server_caps', '_intro_query', '_reset_query', '_proxy', '_stmt_exclusive_section', '_config', '_params', '_addr', '_log_listeners', '_termination_listeners', '_cancellations', '_source_traceback', '_query_loggers', '__weakref__') def __init__(self, protocol, transport, loop, addr, config: connect_utils._ClientConfiguration, params: connect_utils._ConnectionParameters): self._protocol = protocol self._transport = transport self._loop = loop self._top_xact = None self._aborted = False # Incremented every time the connection is released back to a pool. # Used to catch invalid references to connection-related resources # post-release (e.g. explicit prepared statements). self._pool_release_ctr = 0 self._addr = addr self._config = config self._params = params self._stmt_cache = _StatementCache( loop=loop, max_size=config.statement_cache_size, on_remove=functools.partial( _weak_maybe_gc_stmt, weakref.ref(self)), max_lifetime=config.max_cached_statement_lifetime) self._stmts_to_close = set() self._stmt_cache_enabled = config.statement_cache_size > 0 self._listeners = {} self._log_listeners = set() self._cancellations = set() self._termination_listeners = set() self._query_loggers = set() settings = self._protocol.get_settings() ver_string = settings.server_version self._server_version = \ serverversion.split_server_version_string(ver_string) self._server_caps = _detect_server_capabilities( self._server_version, settings) if self._server_version < (14, 0): self._intro_query = introspection.INTRO_LOOKUP_TYPES_13 else: self._intro_query = introspection.INTRO_LOOKUP_TYPES self._reset_query = None self._proxy = None # Used to serialize operations that might involve anonymous # statements. 
Specifically, we want to make the following # operation atomic: # ("prepare an anonymous statement", "use the statement") # # Used for `con.fetchval()`, `con.fetch()`, `con.fetchrow()`, # `con.execute()`, and `con.executemany()`. self._stmt_exclusive_section = _Atomic() if loop.get_debug(): self._source_traceback = _extract_stack() else: self._source_traceback = None def __del__(self): if not self.is_closed() and self._protocol is not None: if self._source_traceback: msg = "unclosed connection {!r}; created at:\n {}".format( self, self._source_traceback) else: msg = ( "unclosed connection {!r}; run in asyncio debug " "mode to show the traceback of connection " "origin".format(self) ) warnings.warn(msg, ResourceWarning) if not self._loop.is_closed(): self.terminate() async def add_listener(self, channel, callback): """Add a listener for Postgres notifications. :param str channel: Channel to listen on. :param callable callback: A callable or a coroutine function receiving the following arguments: **connection**: a Connection the callback is registered with; **pid**: PID of the Postgres server that sent the notification; **channel**: name of the channel the notification was sent to; **payload**: the payload. .. versionchanged:: 0.24.0 The ``callback`` argument may be a coroutine function. 
""" self._check_open() if channel not in self._listeners: await self.fetch('LISTEN {}'.format(utils._quote_ident(channel))) self._listeners[channel] = set() self._listeners[channel].add(_Callback.from_callable(callback)) async def remove_listener(self, channel, callback): """Remove a listening callback on the specified channel.""" if self.is_closed(): return if channel not in self._listeners: return cb = _Callback.from_callable(callback) if cb not in self._listeners[channel]: return self._listeners[channel].remove(cb) if not self._listeners[channel]: del self._listeners[channel] await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel))) def add_log_listener(self, callback): """Add a listener for Postgres log messages. It will be called when asyncronous NoticeResponse is received from the connection. Possible message types are: WARNING, NOTICE, DEBUG, INFO, or LOG. :param callable callback: A callable or a coroutine function receiving the following arguments: **connection**: a Connection the callback is registered with; **message**: the `exceptions.PostgresLogMessage` message. .. versionadded:: 0.12.0 .. versionchanged:: 0.24.0 The ``callback`` argument may be a coroutine function. """ if self.is_closed(): raise exceptions.InterfaceError('connection is closed') self._log_listeners.add(_Callback.from_callable(callback)) def remove_log_listener(self, callback): """Remove a listening callback for log messages. .. versionadded:: 0.12.0 """ self._log_listeners.discard(_Callback.from_callable(callback)) def add_termination_listener(self, callback): """Add a listener that will be called when the connection is closed. :param callable callback: A callable or a coroutine function receiving one argument: **connection**: a Connection the callback is registered with. .. versionadded:: 0.21.0 .. versionchanged:: 0.24.0 The ``callback`` argument may be a coroutine function. 
""" self._termination_listeners.add(_Callback.from_callable(callback)) def remove_termination_listener(self, callback): """Remove a listening callback for connection termination. :param callable callback: The callable or coroutine function that was passed to :meth:`Connection.add_termination_listener`. .. versionadded:: 0.21.0 """ self._termination_listeners.discard(_Callback.from_callable(callback)) def add_query_logger(self, callback): """Add a logger that will be called when queries are executed. :param callable callback: A callable or a coroutine function receiving one argument: **record**, a LoggedQuery containing `query`, `args`, `timeout`, `elapsed`, `exception`, `conn_addr`, and `conn_params`. .. versionadded:: 0.29.0 """ self._query_loggers.add(_Callback.from_callable(callback)) def remove_query_logger(self, callback): """Remove a query logger callback. :param callable callback: The callable or coroutine function that was passed to :meth:`Connection.add_query_logger`. .. versionadded:: 0.29.0 """ self._query_loggers.discard(_Callback.from_callable(callback)) def get_server_pid(self): """Return the PID of the Postgres server the connection is bound to.""" return self._protocol.get_server_pid() def get_server_version(self): """Return the version of the connected PostgreSQL server. The returned value is a named tuple similar to that in ``sys.version_info``: .. code-block:: pycon >>> con.get_server_version() ServerVersion(major=9, minor=6, micro=1, releaselevel='final', serial=0) .. versionadded:: 0.8.0 """ return self._server_version def get_settings(self): """Return connection settings. :return: :class:`~asyncpg.ConnectionSettings`. """ return self._protocol.get_settings() def transaction(self, *, isolation=None, readonly=False, deferrable=False): """Create a :class:`~transaction.Transaction` object. Refer to `PostgreSQL documentation`_ on the meaning of transaction parameters. 
:param isolation: Transaction isolation mode, can be one of: `'serializable'`, `'repeatable_read'`, `'read_uncommitted'`, `'read_committed'`. If not specified, the behavior is up to the server and session, which is usually ``read_committed``. :param readonly: Specifies whether or not this transaction is read-only. :param deferrable: Specifies whether or not this transaction is deferrable. .. _`PostgreSQL documentation`: https://www.postgresql.org/docs/ current/static/sql-set-transaction.html """ self._check_open() return transaction.Transaction(self, isolation, readonly, deferrable) def is_in_transaction(self): """Return True if Connection is currently inside a transaction. :return bool: True if inside transaction, False otherwise. .. versionadded:: 0.16.0 """ return self._protocol.is_in_transaction() async def execute( self, query: str, *args, timeout: typing.Optional[float]=None, ) -> str: """Execute an SQL command (or commands). This method can execute many SQL commands at once, when no arguments are provided. Example: .. code-block:: pycon >>> await con.execute(''' ... CREATE TABLE mytab (a int); ... INSERT INTO mytab (a) VALUES (100), (200), (300); ... ''') INSERT 0 3 >>> await con.execute(''' ... INSERT INTO mytab (a) VALUES ($1), ($2) ... ''', 10, 20) INSERT 0 2 :param args: Query arguments. :param float timeout: Optional timeout value in seconds. :return str: Status of the last SQL command. .. versionchanged:: 0.5.4 Made it possible to pass query arguments. """ self._check_open() if not args: if self._query_loggers: with self._time_and_log(query, args, timeout): result = await self._protocol.query(query, timeout) else: result = await self._protocol.query(query, timeout) return result _, status, _ = await self._execute( query, args, 0, timeout, return_status=True, ) return status.decode() async def executemany( self, command: str, args, *, timeout: typing.Optional[float]=None, ): """Execute an SQL *command* for each sequence of arguments in *args*. 
    async def _get_statement(
        self,
        query,
        timeout,
        *,
        named: typing.Union[str, bool, None] = False,
        use_cache=True,
        ignore_custom_codec=False,
        record_class=None
    ):
        # Return a prepared statement for *query*, consulting and populating
        # the statement cache when allowed.
        #
        # named: a string is used verbatim as the statement name; a truthy
        #     non-string value requests a generated name; otherwise the
        #     statement is anonymous unless it is going to be cached.
        # use_cache: look the statement up in the cache and, subject to the
        #     checks below, store the new one.
        # record_class: validated when given; defaults to the protocol's
        #     record class.
        if record_class is None:
            record_class = self._protocol.get_record_class()
        else:
            _check_record_class(record_class)

        if use_cache:
            statement = self._stmt_cache.get(
                (query, record_class, ignore_custom_codec)
            )
            if statement is not None:
                return statement

            # Only use the cache when:
            #  * `statement_cache_size` is greater than 0;
            #  * query size is less than `max_cacheable_statement_size`.
            use_cache = (
                self._stmt_cache_enabled
                and (
                    not self._config.max_cacheable_statement_size
                    or len(query) <= self._config.max_cacheable_statement_size
                )
            )

        if isinstance(named, str):
            stmt_name = named
        elif use_cache or named:
            stmt_name = self._get_unique_id('stmt')
        else:
            # Anonymous statement.
            stmt_name = ''

        statement = await self._protocol.prepare(
            stmt_name,
            query,
            timeout,
            record_class=record_class,
            ignore_custom_codec=ignore_custom_codec,
        )
        need_reprepare = False
        types_with_missing_codecs = statement._init_types()
        tries = 0
        while types_with_missing_codecs:
            settings = self._protocol.get_settings()

            # Introspect newly seen types and populate the
            # codec cache.
            types, intro_stmt = await self._introspect_types(
                types_with_missing_codecs, timeout)

            settings.register_data_types(types)

            # The introspection query has used an anonymous statement,
            # which has blown away the anonymous statement we've prepared
            # for the query, so we need to re-prepare it.
            need_reprepare = not intro_stmt.name and not statement.name
            types_with_missing_codecs = statement._init_types()
            tries += 1
            if tries > 5:
                # In the vast majority of cases there will be only
                # one iteration.  In rare cases, there might be a race
                # with reload_schema_state(), which would cause a
                # second try.  More than five is clearly a bug.
                raise exceptions.InternalClientError(
                    'could not resolve query result and/or argument types '
                    'in {} attempts'.format(tries)
                )

        # Now that types have been resolved, populate the codec pipeline
        # for the statement.
        statement._init_codecs()

        if (
            need_reprepare
            or (not statement.name and not self._stmt_cache_enabled)
        ):
            # Mark this anonymous prepared statement as "unprepared",
            # causing it to get re-Parsed in next bind_execute.
            # We always do this when stmt_cache_size is set to 0 assuming
            # people are running PgBouncer which is mishandling implicit
            # transactions.
            statement.mark_unprepared()

        if use_cache:
            self._stmt_cache.put(
                (query, record_class, ignore_custom_codec), statement)

        # If we've just created a new statement object, check if there
        # are any statements for GC.
        if self._stmts_to_close:
            await self._cleanup_stmts()

        return statement
    async def _introspect_types(self, typeoids, timeout):
        # Run the type introspection query for *typeoids* and return the
        # result of that call unchanged.
        #
        # NOTE: ``self.__execute`` is name-mangled by Python, so this code
        # has to live inside the class body to resolve correctly.
        if self._server_caps.jit:
            # Switch JIT off for the introspection query, remembering the
            # previous value so it can be restored afterwards.
            try:
                cfgrow, _ = await self.__execute(
                    """
                    SELECT
                        current_setting('jit') AS cur,
                        set_config('jit', 'off', false) AS new
                    """,
                    (),
                    0,
                    timeout,
                    ignore_custom_codec=True,
                )
                jit_state = cfgrow[0]['cur']
            except exceptions.UndefinedObjectError:
                jit_state = 'off'
        else:
            jit_state = 'off'

        result = await self.__execute(
            self._intro_query,
            (list(typeoids),),
            0,
            timeout,
            ignore_custom_codec=True,
        )

        if jit_state != 'off':
            # Restore the JIT setting that was in effect before.
            await self.__execute(
                """
                SELECT
                    set_config('jit', $1, false)
                """,
                (jit_state,),
                0,
                timeout,
                ignore_custom_codec=True,
            )

        return result

    async def _introspect_type(self, typename, schema):
        # Resolve a single type by name.  Non-array names in ``pg_catalog``
        # are first looked up in the built-in type name map; everything
        # else is queried from the server.  Raises ValueError when the
        # query returns no rows.
        if schema == 'pg_catalog' and not typename.endswith("[]"):
            typeoid = protocol.BUILTIN_TYPE_NAME_MAP.get(typename.lower())
            if typeoid is not None:
                return introspection.TypeRecord((typeoid, None, b"b"))

        rows = await self._execute(
            introspection.TYPE_BY_NAME,
            [typename, schema],
            limit=1,
            timeout=None,
            ignore_custom_codec=True,
        )

        if not rows:
            raise ValueError(
                'unknown type: {}.{}'.format(schema, typename))

        return rows[0]

    def cursor(
        self,
        query,
        *args,
        prefetch=None,
        timeout=None,
        record_class=None
    ):
        """Return a *cursor factory* for the specified query.

        :param args:
            Query arguments.
        :param int prefetch:
            The number of rows the *cursor iterator*
            will prefetch (defaults to ``50``.)
        :param float timeout:
            Optional timeout in seconds.
        :param type record_class:
            If specified, the class to use for records returned by this cursor.
            Must be a subclass of :class:`~asyncpg.Record`.  If not specified,
            a per-connection *record_class* is used.

        :return:
            A :class:`~cursor.CursorFactory` object.

        .. versionchanged:: 0.22.0
            Added the *record_class* parameter.
        """
        self._check_open()
        return cursor.CursorFactory(
            self,
            query,
            None,
            args,
            prefetch,
            timeout,
            record_class,
        )
versionchanged:: 0.22.0 Added the *record_class* parameter. """ self._check_open() return cursor.CursorFactory( self, query, None, args, prefetch, timeout, record_class, ) async def prepare( self, query, *, name=None, timeout=None, record_class=None, ): """Create a *prepared statement* for the specified query. :param str query: Text of the query to create a prepared statement for. :param str name: Optional name of the returned prepared statement. If not specified, the name is auto-generated. :param float timeout: Optional timeout value in seconds. :param type record_class: If specified, the class to use for records returned by the prepared statement. Must be a subclass of :class:`~asyncpg.Record`. If not specified, a per-connection *record_class* is used. :return: A :class:`~prepared_stmt.PreparedStatement` instance. .. versionchanged:: 0.22.0 Added the *record_class* parameter. .. versionchanged:: 0.25.0 Added the *name* parameter. """ return await self._prepare( query, name=name, timeout=timeout, record_class=record_class, ) async def _prepare( self, query, *, name: typing.Union[str, bool, None] = None, timeout=None, use_cache: bool=False, record_class=None ): self._check_open() if name is None: name = self._stmt_cache_enabled stmt = await self._get_statement( query, timeout, named=name, use_cache=use_cache, record_class=record_class, ) return prepared_stmt.PreparedStatement(self, query, stmt) async def fetch( self, query, *args, timeout=None, record_class=None ) -> list: """Run a query and return the results as a list of :class:`Record`. :param str query: Query text. :param args: Query arguments. :param float timeout: Optional timeout value in seconds. :param type record_class: If specified, the class to use for records returned by this method. Must be a subclass of :class:`~asyncpg.Record`. If not specified, a per-connection *record_class* is used. :return list: A list of :class:`~asyncpg.Record` instances. 
If specified, the actual type of list elements would be *record_class*. .. versionchanged:: 0.22.0 Added the *record_class* parameter. """ self._check_open() return await self._execute( query, args, 0, timeout, record_class=record_class, ) async def fetchval(self, query, *args, column=0, timeout=None): """Run a query and return a value in the first row. :param str query: Query text. :param args: Query arguments. :param int column: Numeric index within the record of the value to return (defaults to 0). :param float timeout: Optional timeout value in seconds. If not specified, defaults to the value of ``command_timeout`` argument to the ``Connection`` instance constructor. :return: The value of the specified column of the first record, or None if no records were returned by the query. """ self._check_open() data = await self._execute(query, args, 1, timeout) if not data: return None return data[0][column] async def fetchrow( self, query, *args, timeout=None, record_class=None ): """Run a query and return the first row. :param str query: Query text :param args: Query arguments :param float timeout: Optional timeout value in seconds. :param type record_class: If specified, the class to use for the value returned by this method. Must be a subclass of :class:`~asyncpg.Record`. If not specified, a per-connection *record_class* is used. :return: The first row as a :class:`~asyncpg.Record` instance, or None if no records were returned by the query. If specified, *record_class* is used as the type for the result value. .. versionchanged:: 0.22.0 Added the *record_class* parameter. """ self._check_open() data = await self._execute( query, args, 1, timeout, record_class=record_class, ) if not data: return None return data[0] async def fetchmany( self, query, args, *, timeout: typing.Optional[float]=None, record_class=None, ): """Run a query for each sequence of arguments in *args* and return the results as a list of :class:`Record`. :param query: Query to execute. 
:param args: An iterable containing sequences of arguments for the query. :param float timeout: Optional timeout value in seconds. :param type record_class: If specified, the class to use for records returned by this method. Must be a subclass of :class:`~asyncpg.Record`. If not specified, a per-connection *record_class* is used. :return list: A list of :class:`~asyncpg.Record` instances. If specified, the actual type of list elements would be *record_class*. Example: .. code-block:: pycon >>> rows = await con.fetchmany(''' ... INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a; ... ''', [('x', 1), ('y', 2), ('z', 3)]) >>> rows [, , ] .. versionadded:: 0.30.0 """ self._check_open() return await self._executemany( query, args, timeout, return_rows=True, record_class=record_class ) async def copy_from_table(self, table_name, *, output, columns=None, schema_name=None, timeout=None, format=None, oids=None, delimiter=None, null=None, header=None, quote=None, escape=None, force_quote=None, encoding=None): """Copy table contents to a file or file-like object. :param str table_name: The name of the table to copy data from. :param output: A :term:`path-like object `, or a :term:`file-like object `, or a :term:`coroutine function ` that takes a ``bytes`` instance as a sole argument. :param list columns: An optional list of column names to copy. :param str schema_name: An optional schema name to qualify the table. :param float timeout: Optional timeout value in seconds. The remaining keyword arguments are ``COPY`` statement options, see `COPY statement documentation`_ for details. :return: The status string of the COPY command. Example: .. code-block:: pycon >>> import asyncpg >>> import asyncio >>> async def run(): ... con = await asyncpg.connect(user='postgres') ... result = await con.copy_from_table( ... 'mytable', columns=('foo', 'bar'), ... output='file.csv', format='csv') ... print(result) ... >>> asyncio.run(run()) 'COPY 100' .. 
_`COPY statement documentation`: https://www.postgresql.org/docs/current/static/sql-copy.html .. versionadded:: 0.11.0 """ tabname = utils._quote_ident(table_name) if schema_name: tabname = utils._quote_ident(schema_name) + '.' + tabname if columns: cols = '({})'.format( ', '.join(utils._quote_ident(c) for c in columns)) else: cols = '' opts = self._format_copy_opts( format=format, oids=oids, delimiter=delimiter, null=null, header=header, quote=quote, escape=escape, force_quote=force_quote, encoding=encoding ) copy_stmt = 'COPY {tab}{cols} TO STDOUT {opts}'.format( tab=tabname, cols=cols, opts=opts) return await self._copy_out(copy_stmt, output, timeout) async def copy_from_query(self, query, *args, output, timeout=None, format=None, oids=None, delimiter=None, null=None, header=None, quote=None, escape=None, force_quote=None, encoding=None): """Copy the results of a query to a file or file-like object. :param str query: The query to copy the results of. :param args: Query arguments. :param output: A :term:`path-like object `, or a :term:`file-like object `, or a :term:`coroutine function ` that takes a ``bytes`` instance as a sole argument. :param float timeout: Optional timeout value in seconds. The remaining keyword arguments are ``COPY`` statement options, see `COPY statement documentation`_ for details. :return: The status string of the COPY command. Example: .. code-block:: pycon >>> import asyncpg >>> import asyncio >>> async def run(): ... con = await asyncpg.connect(user='postgres') ... result = await con.copy_from_query( ... 'SELECT foo, bar FROM mytable WHERE foo > $1', 10, ... output='file.csv', format='csv') ... print(result) ... >>> asyncio.run(run()) 'COPY 10' .. _`COPY statement documentation`: https://www.postgresql.org/docs/current/static/sql-copy.html .. 
versionadded:: 0.11.0 """ opts = self._format_copy_opts( format=format, oids=oids, delimiter=delimiter, null=null, header=header, quote=quote, escape=escape, force_quote=force_quote, encoding=encoding ) if args: query = await utils._mogrify(self, query, args) copy_stmt = 'COPY ({query}) TO STDOUT {opts}'.format( query=query, opts=opts) return await self._copy_out(copy_stmt, output, timeout) async def copy_to_table(self, table_name, *, source, columns=None, schema_name=None, timeout=None, format=None, oids=None, freeze=None, delimiter=None, null=None, header=None, quote=None, escape=None, force_quote=None, force_not_null=None, force_null=None, encoding=None, where=None): """Copy data to the specified table. :param str table_name: The name of the table to copy data to. :param source: A :term:`path-like object `, or a :term:`file-like object `, or an :term:`asynchronous iterable ` that returns ``bytes``, or an object supporting the :ref:`buffer protocol `. :param list columns: An optional list of column names to copy. :param str schema_name: An optional schema name to qualify the table. :param str where: An optional SQL expression used to filter rows when copying. .. note:: Usage of this parameter requires support for the ``COPY FROM ... WHERE`` syntax, introduced in PostgreSQL version 12. :param float timeout: Optional timeout value in seconds. The remaining keyword arguments are ``COPY`` statement options, see `COPY statement documentation`_ for details. :return: The status string of the COPY command. Example: .. code-block:: pycon >>> import asyncpg >>> import asyncio >>> async def run(): ... con = await asyncpg.connect(user='postgres') ... result = await con.copy_to_table( ... 'mytable', source='datafile.tbl') ... print(result) ... >>> asyncio.run(run()) 'COPY 140000' .. _`COPY statement documentation`: https://www.postgresql.org/docs/current/static/sql-copy.html .. versionadded:: 0.11.0 .. versionadded:: 0.29.0 Added the *where* parameter. 
""" tabname = utils._quote_ident(table_name) if schema_name: tabname = utils._quote_ident(schema_name) + '.' + tabname if columns: cols = '({})'.format( ', '.join(utils._quote_ident(c) for c in columns)) else: cols = '' cond = self._format_copy_where(where) opts = self._format_copy_opts( format=format, oids=oids, freeze=freeze, delimiter=delimiter, null=null, header=header, quote=quote, escape=escape, force_not_null=force_not_null, force_null=force_null, encoding=encoding ) copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format( tab=tabname, cols=cols, opts=opts, cond=cond) return await self._copy_in(copy_stmt, source, timeout) async def copy_records_to_table(self, table_name, *, records, columns=None, schema_name=None, timeout=None, where=None): """Copy a list of records to the specified table using binary COPY. :param str table_name: The name of the table to copy data to. :param records: An iterable returning row tuples to copy into the table. :term:`Asynchronous iterables ` are also supported. :param list columns: An optional list of column names to copy. :param str schema_name: An optional schema name to qualify the table. :param str where: An optional SQL expression used to filter rows when copying. .. note:: Usage of this parameter requires support for the ``COPY FROM ... WHERE`` syntax, introduced in PostgreSQL version 12. :param float timeout: Optional timeout value in seconds. :return: The status string of the COPY command. Example: .. code-block:: pycon >>> import asyncpg >>> import asyncio >>> async def run(): ... con = await asyncpg.connect(user='postgres') ... result = await con.copy_records_to_table( ... 'mytable', records=[ ... (1, 'foo', 'bar'), ... (2, 'ham', 'spam')]) ... print(result) ... >>> asyncio.run(run()) 'COPY 2' Asynchronous record iterables are also supported: .. code-block:: pycon >>> import asyncpg >>> import asyncio >>> async def run(): ... con = await asyncpg.connect(user='postgres') ... async def record_gen(size): ... 
for i in range(size): ... yield (i,) ... result = await con.copy_records_to_table( ... 'mytable', records=record_gen(100)) ... print(result) ... >>> asyncio.run(run()) 'COPY 100' .. versionadded:: 0.11.0 .. versionchanged:: 0.24.0 The ``records`` argument may be an asynchronous iterable. .. versionadded:: 0.29.0 Added the *where* parameter. """ tabname = utils._quote_ident(table_name) if schema_name: tabname = utils._quote_ident(schema_name) + '.' + tabname if columns: col_list = ', '.join(utils._quote_ident(c) for c in columns) cols = '({})'.format(col_list) else: col_list = '*' cols = '' intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format( tab=tabname, cols=col_list) intro_ps = await self.prepare(intro_query) cond = self._format_copy_where(where) opts = '(FORMAT binary)' copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format( tab=tabname, cols=cols, opts=opts, cond=cond) return await self._protocol.copy_in( copy_stmt, None, None, records, intro_ps._state, timeout) def _format_copy_where(self, where): if where and not self._server_caps.sql_copy_from_where: raise exceptions.UnsupportedServerFeatureError( 'the `where` parameter requires PostgreSQL 12 or later') if where: where_clause = 'WHERE ' + where else: where_clause = '' return where_clause def _format_copy_opts(self, *, format=None, oids=None, freeze=None, delimiter=None, null=None, header=None, quote=None, escape=None, force_quote=None, force_not_null=None, force_null=None, encoding=None): kwargs = dict(locals()) kwargs.pop('self') opts = [] if force_quote is not None and isinstance(force_quote, bool): kwargs.pop('force_quote') if force_quote: opts.append('FORCE_QUOTE *') for k, v in kwargs.items(): if v is not None: if k in ('force_not_null', 'force_null', 'force_quote'): v = '(' + ', '.join(utils._quote_ident(c) for c in v) + ')' elif k in ('oids', 'freeze', 'header'): v = str(v) else: v = utils._quote_literal(v) opts.append('{} {}'.format(k.upper(), v)) if opts: return '(' + ', '.join(opts) + 
')' else: return '' async def _copy_out(self, copy_stmt, output, timeout): try: path = os.fspath(output) except TypeError: # output is not a path-like object path = None writer = None opened_by_us = False run_in_executor = self._loop.run_in_executor if path is not None: # a path f = await run_in_executor(None, open, path, 'wb') opened_by_us = True elif hasattr(output, 'write'): # file-like f = output elif callable(output): # assuming calling output returns an awaitable. writer = output else: raise TypeError( 'output is expected to be a file-like object, ' 'a path-like object or a coroutine function, ' 'not {}'.format(type(output).__name__) ) if writer is None: async def _writer(data): await run_in_executor(None, f.write, data) writer = _writer try: return await self._protocol.copy_out(copy_stmt, writer, timeout) finally: if opened_by_us: f.close() async def _copy_in(self, copy_stmt, source, timeout): try: path = os.fspath(source) except TypeError: # source is not a path-like object path = None f = None reader = None data = None opened_by_us = False run_in_executor = self._loop.run_in_executor if path is not None: # a path f = await run_in_executor(None, open, path, 'rb') opened_by_us = True elif hasattr(source, 'read'): # file-like f = source elif isinstance(source, collections.abc.AsyncIterable): # assuming calling output returns an awaitable. # copy_in() is designed to handle very large amounts of data, and # the source async iterable is allowed to return an arbitrary # amount of data on every iteration. reader = source else: # assuming source is an instance supporting the buffer protocol. data = source if f is not None: # Copying from a file-like object. 
class _Reader: def __aiter__(self): return self async def __anext__(self): data = await run_in_executor(None, f.read, 524288) if len(data) == 0: raise StopAsyncIteration else: return data reader = _Reader() try: return await self._protocol.copy_in( copy_stmt, reader, data, None, None, timeout) finally: if opened_by_us: await run_in_executor(None, f.close) async def set_type_codec(self, typename, *, schema='public', encoder, decoder, format='text'): """Set an encoder/decoder pair for the specified data type. :param typename: Name of the data type the codec is for. :param schema: Schema name of the data type the codec is for (defaults to ``'public'``) :param format: The type of the argument received by the *decoder* callback, and the type of the *encoder* callback return value. If *format* is ``'text'`` (the default), the exchange datum is a ``str`` instance containing valid text representation of the data type. If *format* is ``'binary'``, the exchange datum is a ``bytes`` instance containing valid _binary_ representation of the data type. If *format* is ``'tuple'``, the exchange datum is a type-specific ``tuple`` of values. The table below lists supported data types and their format for this mode. +-----------------+---------------------------------------------+ | Type | Tuple layout | +=================+=============================================+ | ``interval`` | (``months``, ``days``, ``microseconds``) | +-----------------+---------------------------------------------+ | ``date`` | (``date ordinal relative to Jan 1 2000``,) | | | ``-2^31`` for negative infinity timestamp | | | ``2^31-1`` for positive infinity timestamp. | +-----------------+---------------------------------------------+ | ``timestamp`` | (``microseconds relative to Jan 1 2000``,) | | | ``-2^63`` for negative infinity timestamp | | | ``2^63-1`` for positive infinity timestamp. 
| +-----------------+---------------------------------------------+ | ``timestamp | (``microseconds relative to Jan 1 2000 | | with time zone``| UTC``,) | | | ``-2^63`` for negative infinity timestamp | | | ``2^63-1`` for positive infinity timestamp. | +-----------------+---------------------------------------------+ | ``time`` | (``microseconds``,) | +-----------------+---------------------------------------------+ | ``time with | (``microseconds``, | | time zone`` | ``time zone offset in seconds``) | +-----------------+---------------------------------------------+ | any composite | Composite value elements | | type | | +-----------------+---------------------------------------------+ :param encoder: Callable accepting a Python object as a single argument and returning a value encoded according to *format*. :param decoder: Callable accepting a single argument encoded according to *format* and returning a decoded Python object. Example: .. code-block:: pycon >>> import asyncpg >>> import asyncio >>> import datetime >>> from dateutil.relativedelta import relativedelta >>> async def run(): ... con = await asyncpg.connect(user='postgres') ... def encoder(delta): ... ndelta = delta.normalized() ... return (ndelta.years * 12 + ndelta.months, ... ndelta.days, ... ((ndelta.hours * 3600 + ... ndelta.minutes * 60 + ... ndelta.seconds) * 1000000 + ... ndelta.microseconds)) ... def decoder(tup): ... return relativedelta(months=tup[0], days=tup[1], ... microseconds=tup[2]) ... await con.set_type_codec( ... 'interval', schema='pg_catalog', encoder=encoder, ... decoder=decoder, format='tuple') ... result = await con.fetchval( ... "SELECT '2 years 3 mons 1 day'::interval") ... print(result) ... print(datetime.datetime(2002, 1, 1) + result) ... >>> asyncio.run(run()) relativedelta(years=+2, months=+3, days=+1) 2004-04-02 00:00:00 .. versionadded:: 0.12.0 Added the ``format`` keyword argument and support for 'tuple' format. .. 
versionchanged:: 0.12.0 The ``binary`` keyword argument is deprecated in favor of ``format``. .. versionchanged:: 0.13.0 The ``binary`` keyword argument was removed in favor of ``format``. .. versionchanged:: 0.29.0 Custom codecs for composite types are now supported with ``format='tuple'``. .. note:: It is recommended to use the ``'binary'`` or ``'tuple'`` *format* whenever possible and if the underlying type supports it. Asyncpg currently does not support text I/O for composite and range types, and some other functionality, such as :meth:`Connection.copy_to_table`, does not support types with text codecs. """ self._check_open() settings = self._protocol.get_settings() typeinfo = await self._introspect_type(typename, schema) full_typeinfos = [] if introspection.is_scalar_type(typeinfo): kind = 'scalar' elif introspection.is_composite_type(typeinfo): if format != 'tuple': raise exceptions.UnsupportedClientFeatureError( 'only tuple-format codecs can be used on composite types', hint="Use `set_type_codec(..., format='tuple')` and " "pass/interpret data as a Python tuple. See an " "example at https://magicstack.github.io/asyncpg/" "current/usage.html#example-decoding-complex-types", ) kind = 'composite' full_typeinfos, _ = await self._introspect_types( (typeinfo['oid'],), 10) else: raise exceptions.InterfaceError( f'cannot use custom codec on type {schema}.{typename}: ' f'it is neither a scalar type nor a composite type' ) if introspection.is_domain_type(typeinfo): raise exceptions.UnsupportedClientFeatureError( 'custom codecs on domain types are not supported', hint='Set the codec on the base type.', detail=( 'PostgreSQL does not distinguish domains from ' 'their base types in query results at the protocol level.' ) ) oid = typeinfo['oid'] settings.add_python_codec( oid, typename, schema, full_typeinfos, kind, encoder, decoder, format) # Statement cache is no longer valid due to codec changes. 
async def reset_type_codec(self, typename, *, schema='public'):
    """Reset *typename* codec to the default implementation.

    :param typename:
        Name of the data type the codec is for.

    :param schema:
        Schema name of the data type the codec is for
        (defaults to ``'public'``)

    :raises InterfaceError:
        If the connection is closed.

    .. versionadded:: 0.12.0
    """
    # Fail fast with a clear "connection is closed" error, the same way
    # set_type_codec() and set_builtin_type_codec() do, instead of failing
    # later on a protocol object that _abort() has already set to None.
    self._check_open()

    typeinfo = await self._introspect_type(typename, schema)
    self._protocol.get_settings().remove_python_codec(
        typeinfo['oid'], typename, schema)

    # Statement cache is no longer valid due to codec changes.
    self._drop_local_statement_cache()
""" self._check_open() typeinfo = await self._introspect_type(typename, schema) if not introspection.is_scalar_type(typeinfo): raise exceptions.InterfaceError( 'cannot alias non-scalar type {}.{}'.format( schema, typename)) oid = typeinfo['oid'] self._protocol.get_settings().set_builtin_type_codec( oid, typename, schema, 'scalar', codec_name, format) # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() def is_closed(self): """Return ``True`` if the connection is closed, ``False`` otherwise. :return bool: ``True`` if the connection is closed, ``False`` otherwise. """ return self._aborted or not self._protocol.is_connected() async def close(self, *, timeout=None): """Close the connection gracefully. :param float timeout: Optional timeout value in seconds. .. versionchanged:: 0.14.0 Added the *timeout* parameter. """ try: if not self.is_closed(): await self._protocol.close(timeout) except (Exception, asyncio.CancelledError): # If we fail to close gracefully, abort the connection. self._abort() raise finally: self._cleanup() def terminate(self): """Terminate the connection without waiting for pending data.""" if not self.is_closed(): self._abort() self._cleanup() async def _reset(self): self._check_open() self._listeners.clear() self._log_listeners.clear() if self._protocol.is_in_transaction() or self._top_xact is not None: if self._top_xact is None or not self._top_xact._managed: # Managed transactions are guaranteed to __aexit__ # correctly. self._loop.call_exception_handler({ 'message': 'Resetting connection with an ' 'active transaction {!r}'.format(self) }) self._top_xact = None await self.execute("ROLLBACK") async def reset(self, *, timeout=None): """Reset the connection state. Calling this will reset the connection session state to a state resembling that of a newly obtained connection. 
Namely, an open transaction (if any) is rolled back, open cursors are closed, all `LISTEN `_ registrations are removed, all session configuration variables are reset to their default values, and all advisory locks are released. Note that the above describes the default query returned by :meth:`Connection.get_reset_query`. If one overloads the method by subclassing ``Connection``, then this method will do whatever the overloaded method returns, except open transactions are always terminated and any callbacks registered by :meth:`Connection.add_listener` or :meth:`Connection.add_log_listener` are removed. :param float timeout: A timeout for resetting the connection. If not specified, defaults to no timeout. """ async with compat.timeout(timeout): await self._reset() reset_query = self.get_reset_query() if reset_query: await self.execute(reset_query) def _abort(self): # Put the connection into the aborted state. self._aborted = True self._protocol.abort() self._protocol = None def _cleanup(self): self._call_termination_listeners() # Free the resources associated with this connection. # This must be called when a connection is terminated. if self._proxy is not None: # Connection is a member of a pool, so let the pool # know that this connection is dead. self._proxy._holder._release_on_close() self._mark_stmts_as_closed() self._listeners.clear() self._log_listeners.clear() self._query_loggers.clear() self._clean_tasks() def _clean_tasks(self): # Wrap-up any remaining tasks associated with this connection. 
if self._cancellations: for fut in self._cancellations: if not fut.done(): fut.cancel() self._cancellations.clear() def _check_open(self): if self.is_closed(): raise exceptions.InterfaceError('connection is closed') def _get_unique_id(self, prefix): global _uid _uid += 1 return '__asyncpg_{}_{:x}__'.format(prefix, _uid) def _mark_stmts_as_closed(self): for stmt in self._stmt_cache.iter_statements(): stmt.mark_closed() for stmt in self._stmts_to_close: stmt.mark_closed() self._stmt_cache.clear() self._stmts_to_close.clear() def _maybe_gc_stmt(self, stmt): if ( stmt.refs == 0 and stmt.name and not self._stmt_cache.has( (stmt.query, stmt.record_class, stmt.ignore_custom_codec) ) ): # If low-level `stmt` isn't referenced from any high-level # `PreparedStatement` object and is not in the `_stmt_cache`: # # * mark it as closed, which will make it non-usable # for any `PreparedStatement` or for methods like # `Connection.fetch()`. # # * schedule it to be formally closed on the server. stmt.mark_closed() self._stmts_to_close.add(stmt) async def _cleanup_stmts(self): # Called whenever we create a new prepared statement in # `Connection._get_statement()` and `_stmts_to_close` is # not empty. to_close = self._stmts_to_close self._stmts_to_close = set() for stmt in to_close: # It is imperative that statements are cleaned properly, # so we ignore the timeout. await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT) async def _cancel(self, waiter): try: # Open new connection to the server await connect_utils._cancel( loop=self._loop, addr=self._addr, params=self._params, backend_pid=self._protocol.backend_pid, backend_secret=self._protocol.backend_secret) except ConnectionResetError as ex: # On some systems Postgres will reset the connection # after processing the cancellation command. 
if not waiter.done(): waiter.set_exception(ex) except asyncio.CancelledError: # There are two scenarios in which the cancellation # itself will be cancelled: 1) the connection is being closed, # 2) the event loop is being shut down. # In either case we do not care about the propagation of # the CancelledError, and don't want the loop to warn about # an unretrieved exception. pass except (Exception, asyncio.CancelledError) as ex: if not waiter.done(): waiter.set_exception(ex) finally: self._cancellations.discard( asyncio.current_task(self._loop)) if not waiter.done(): waiter.set_result(None) def _cancel_current_command(self, waiter): self._cancellations.add(self._loop.create_task(self._cancel(waiter))) def _process_log_message(self, fields, last_query): if not self._log_listeners: return message = exceptions.PostgresLogMessage.new(fields, query=last_query) con_ref = self._unwrap() for cb in self._log_listeners: if cb.is_async: self._loop.create_task(cb.cb(con_ref, message)) else: self._loop.call_soon(cb.cb, con_ref, message) def _call_termination_listeners(self): if not self._termination_listeners: return con_ref = self._unwrap() for cb in self._termination_listeners: if cb.is_async: self._loop.create_task(cb.cb(con_ref)) else: self._loop.call_soon(cb.cb, con_ref) self._termination_listeners.clear() def _process_notification(self, pid, channel, payload): if channel not in self._listeners: return con_ref = self._unwrap() for cb in self._listeners[channel]: if cb.is_async: self._loop.create_task(cb.cb(con_ref, pid, channel, payload)) else: self._loop.call_soon(cb.cb, con_ref, pid, channel, payload) def _unwrap(self): if self._proxy is None: con_ref = self else: # `_proxy` is not None when the connection is a member # of a connection pool. Which means that the user is working # with a `PoolConnectionProxy` instance, and expects to see it # (and not the actual Connection) in their event callbacks. 
def get_reset_query(self):
    """Return the query sent to server on connection release.

    The query returned by this method is used by :meth:`Connection.reset`,
    which is, in turn, used by :class:`~asyncpg.pool.Pool` before making
    the connection available to another acquirer.

    The query is assembled once from the server capabilities and then
    cached on the connection; later calls return the cached value.

    .. versionadded:: 0.30.0
    """
    cached = self._reset_query
    if cached is not None:
        return cached

    caps = self._server_caps

    # (capability available?, statement to run), in execution order.
    candidates = (
        (caps.advisory_locks, 'SELECT pg_advisory_unlock_all();'),
        (caps.sql_close_all, 'CLOSE ALL;'),
        (caps.notifications and caps.plpgsql, 'UNLISTEN *;'),
        (caps.sql_reset, 'RESET ALL;'),
    )

    reset_sql = '\n'.join(
        statement for available, statement in candidates if available)
    self._reset_query = reset_sql
    return reset_sql
pool = self._proxy._holder._pool pool._drop_statement_cache() else: self._drop_local_statement_cache() def _drop_local_type_cache(self): self._protocol.get_settings().clear_type_cache() def _drop_global_type_cache(self): if self._proxy is not None: # This connection is a member of a pool, so we delegate # the cache drop to the pool. pool = self._proxy._holder._pool pool._drop_type_cache() else: self._drop_local_type_cache() async def reload_schema_state(self): """Indicate that the database schema information must be reloaded. For performance reasons, asyncpg caches certain aspects of the database schema, such as the layout of composite types. Consequently, when the database schema changes, and asyncpg is not able to gracefully recover from an error caused by outdated schema assumptions, an :exc:`~asyncpg.exceptions.OutdatedSchemaCacheError` is raised. To prevent the exception, this method may be used to inform asyncpg that the database schema has changed. Example: .. code-block:: pycon >>> import asyncpg >>> import asyncio >>> async def change_type(con): ... result = await con.fetch('SELECT id, info FROM tbl') ... # Change composite's attribute type "int"=>"text" ... await con.execute('ALTER TYPE custom DROP ATTRIBUTE y') ... await con.execute('ALTER TYPE custom ADD ATTRIBUTE y text') ... await con.reload_schema_state() ... for id_, info in result: ... new = (info['x'], str(info['y'])) ... await con.execute( ... 'UPDATE tbl SET info=$2 WHERE id=$1', id_, new) ... >>> async def run(): ... # Initial schema: ... # CREATE TYPE custom AS (x int, y int); ... # CREATE TABLE tbl(id int, info custom); ... con = await asyncpg.connect(user='postgres') ... async with con.transaction(): ... # Prevent concurrent changes in the table ... await con.execute('LOCK TABLE tbl') ... await change_type(con) ... >>> asyncio.run(run()) .. 
versionadded:: 0.14.0 """ self._drop_global_type_cache() self._drop_global_statement_cache() async def _execute( self, query, args, limit, timeout, *, return_status=False, ignore_custom_codec=False, record_class=None ): with self._stmt_exclusive_section: result, _ = await self.__execute( query, args, limit, timeout, return_status=return_status, record_class=record_class, ignore_custom_codec=ignore_custom_codec, ) return result @contextlib.contextmanager def query_logger(self, callback): """Context manager that adds `callback` to the list of query loggers, and removes it upon exit. :param callable callback: A callable or a coroutine function receiving one argument: **record**, a LoggedQuery containing `query`, `args`, `timeout`, `elapsed`, `exception`, `conn_addr`, and `conn_params`. Example: .. code-block:: pycon >>> class QuerySaver: def __init__(self): self.queries = [] def __call__(self, record): self.queries.append(record.query) >>> with con.query_logger(QuerySaver()): >>> await con.execute("SELECT 1") >>> print(log.queries) ['SELECT 1'] .. 
versionadded:: 0.29.0 """ self.add_query_logger(callback) yield self.remove_query_logger(callback) @contextlib.contextmanager def _time_and_log(self, query, args, timeout): start = time.monotonic() exception = None try: yield except BaseException as ex: exception = ex raise finally: elapsed = time.monotonic() - start record = LoggedQuery( query=query, args=args, timeout=timeout, elapsed=elapsed, exception=exception, conn_addr=self._addr, conn_params=self._params, ) for cb in self._query_loggers: if cb.is_async: self._loop.create_task(cb.cb(record)) else: self._loop.call_soon(cb.cb, record) async def __execute( self, query, args, limit, timeout, *, return_status=False, ignore_custom_codec=False, record_class=None ): executor = lambda stmt, timeout: self._protocol.bind_execute( state=stmt, args=args, portal_name='', limit=limit, return_extra=return_status, timeout=timeout, ) timeout = self._protocol._get_timeout(timeout) if self._query_loggers: with self._time_and_log(query, args, timeout): result, stmt = await self._do_execute( query, executor, timeout, record_class=record_class, ignore_custom_codec=ignore_custom_codec, ) else: result, stmt = await self._do_execute( query, executor, timeout, record_class=record_class, ignore_custom_codec=ignore_custom_codec, ) return result, stmt async def _executemany( self, query, args, timeout, return_rows=False, record_class=None, ): executor = lambda stmt, timeout: self._protocol.bind_execute_many( state=stmt, args=args, portal_name='', timeout=timeout, return_rows=return_rows, ) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: with self._time_and_log(query, args, timeout): result, _ = await self._do_execute( query, executor, timeout, record_class=record_class ) return result async def _do_execute( self, query, executor, timeout, retry=True, *, ignore_custom_codec=False, record_class=None ): if timeout is None: stmt = await self._get_statement( query, None, record_class=record_class, 
ignore_custom_codec=ignore_custom_codec, ) else: before = time.monotonic() stmt = await self._get_statement( query, timeout, record_class=record_class, ignore_custom_codec=ignore_custom_codec, ) after = time.monotonic() timeout -= after - before before = after try: if timeout is None: result = await executor(stmt, None) else: try: result = await executor(stmt, timeout) finally: after = time.monotonic() timeout -= after - before except exceptions.OutdatedSchemaCacheError: # This exception is raised when we detect a difference between # cached type's info and incoming tuple from the DB (when a type is # changed by the ALTER TYPE). # It is not possible to recover (the statement is already done at # the server's side), the only way is to drop our caches and # reraise the exception to the caller. await self.reload_schema_state() raise except exceptions.InvalidCachedStatementError: # PostgreSQL will raise an exception when it detects # that the result type of the query has changed from # when the statement was prepared. This may happen, # for example, after an ALTER TABLE or SET search_path. # # When this happens, and there is no transaction running, # we can simply re-prepare the statement and try once # again. We deliberately retry only once as this is # supposed to be a rare occurrence. # # If the transaction _is_ running, this error will put it # into an error state, and we have no choice but to # re-raise the exception. # # In either case we clear the statement cache for this # connection and all other connections of the pool this # connection belongs to (if any). # # See https://github.com/MagicStack/asyncpg/issues/72 # and https://github.com/MagicStack/asyncpg/issues/76 # for discussion. 
# self._drop_global_statement_cache() if self._protocol.is_in_transaction() or not retry: raise else: return await self._do_execute( query, executor, timeout, retry=False) return result, stmt async def connect(dsn=None, *, host=None, port=None, user=None, password=None, passfile=None, service=None, servicefile=None, database=None, loop=None, timeout=60, statement_cache_size=100, max_cached_statement_lifetime=300, max_cacheable_statement_size=1024 * 15, command_timeout=None, ssl=None, direct_tls=None, connection_class=Connection, record_class=protocol.Record, server_settings=None, target_session_attrs=None, krbsrvname=None, gsslib=None): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection URI in *dsn*, or as specific keyword arguments, or both. If both *dsn* and keyword arguments are specified, the latter override the corresponding values parsed from the connection URI. The default values for the majority of arguments can be specified using `environment variables `_. Returns a new :class:`~asyncpg.connection.Connection` object. :param dsn: Connection arguments specified using as a single string in the `libpq connection URI format`_: ``postgres://user:password@host:port/database?option=value``. The following options are recognized by asyncpg: ``host``, ``port``, ``user``, ``database`` (or ``dbname``), ``password``, ``passfile``, ``sslmode``, ``sslcert``, ``sslkey``, ``sslrootcert``, and ``sslcrl``. Unlike libpq, asyncpg will treat unrecognized options as `server settings`_ to be used for the connection. .. note:: The URI must be *valid*, which means that all components must be properly quoted with :py:func:`urllib.parse.quote_plus`, and any literal IPv6 addresses must be enclosed in square brackets. For example: .. 
code-block:: text postgres://dbuser@[fe80::1ff:fe23:4567:890a%25eth0]/dbname :param host: Database host address as one of the following: - an IP address or a domain name; - an absolute path to the directory containing the database server Unix-domain socket (not supported on Windows); - a sequence of any of the above, in which case the addresses will be tried in order, and the first successful connection will be returned. If not specified, asyncpg will try the following, in order: - host address(es) parsed from the *dsn* argument, - the value of the ``PGHOST`` environment variable, - on Unix, common directories used for PostgreSQL Unix-domain sockets: ``"/run/postgresql"``, ``"/var/run/postgresl"``, ``"/var/pgsql_socket"``, ``"/private/tmp"``, and ``"/tmp"``, - ``"localhost"``. :param port: Port number to connect to at the server host (or Unix-domain socket file extension). If multiple host addresses were specified, this parameter may specify a sequence of port numbers of the same length as the host sequence, or it may specify a single port number to be used for all host addresses. If not specified, the value parsed from the *dsn* argument is used, or the value of the ``PGPORT`` environment variable, or ``5432`` if neither is specified. :param user: The name of the database role used for authentication. If not specified, the value parsed from the *dsn* argument is used, or the value of the ``PGUSER`` environment variable, or the operating system name of the user running the application. :param database: The name of the database to connect to. If not specified, the value parsed from the *dsn* argument is used, or the value of the ``PGDATABASE`` environment variable, or the computed value of the *user* argument. :param password: Password to be used for authentication, if the server requires one. If not specified, the value parsed from the *dsn* argument is used, or the value of the ``PGPASSWORD`` environment variable. 
Note that the use of the environment variable is discouraged as other users and applications may be able to read it without needing specific privileges. It is recommended to use *passfile* instead. Password may be either a string, or a callable that returns a string. If a callable is provided, it will be called each time a new connection is established. :param passfile: The name of the file used to store passwords (defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf`` on Windows). :param service: The name of the postgres connection service stored in the postgres connection service file. :param servicefile: The location of the connection service file used to store connection parameters. :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. :param float timeout: Connection timeout in seconds. :param int statement_cache_size: The size of the prepared statement LRU cache. Pass ``0`` to disable the cache. :param int max_cached_statement_lifetime: The maximum time in seconds a prepared statement will stay in the cache. Pass ``0`` to allow statements to be cached indefinitely. :param int max_cacheable_statement_size: The maximum size of a statement that can be cached (15KiB by default). Pass ``0`` to allow all statements to be cached regardless of their size. :param float command_timeout: The default timeout for operations on this connection (the default is ``None``: no timeout). :param ssl: Pass ``True`` or an `ssl.SSLContext `_ instance to require an SSL connection. If ``True``, a default SSL context returned by `ssl.create_default_context() `_ will be used. The value can also be one of the following strings: - ``'disable'`` - SSL is disabled (equivalent to ``False``) - ``'prefer'`` - try SSL first, fallback to non-SSL connection if SSL connection fails - ``'allow'`` - try without SSL first, then retry with SSL if the first attempt fails. - ``'require'`` - only try an SSL connection.
Certificate verification errors are ignored - ``'verify-ca'`` - only try an SSL connection, and verify that the server certificate is issued by a trusted certificate authority (CA) - ``'verify-full'`` - only try an SSL connection, verify that the server certificate is issued by a trusted CA and that the requested server host name matches that in the certificate. The default is ``'prefer'``: try an SSL connection and fallback to non-SSL connection if that fails. .. note:: *ssl* is ignored for Unix domain socket communication. Example of programmatic SSL context configuration that is equivalent to ``sslmode=verify-full&sslcert=..&sslkey=..&sslrootcert=..``: .. code-block:: pycon >>> import asyncpg >>> import asyncio >>> import ssl >>> async def main(): ... # Load CA bundle for server certificate verification, ... # equivalent to sslrootcert= in DSN. ... sslctx = ssl.create_default_context( ... ssl.Purpose.SERVER_AUTH, ... cafile="path/to/ca_bundle.pem") ... # If True, equivalent to sslmode=verify-full, if False: ... # sslmode=verify-ca. ... sslctx.check_hostname = True ... # Load client certificate and private key for client ... # authentication, equivalent to sslcert= and sslkey= in ... # DSN. ... sslctx.load_cert_chain( ... "path/to/client.cert", ... keyfile="path/to/client.key", ... ) ... con = await asyncpg.connect(user='postgres', ssl=sslctx) ... await con.close() >>> asyncio.run(main()) Example of programmatic SSL context configuration that is equivalent to ``sslmode=require`` (no server certificate or host verification): .. code-block:: pycon >>> import asyncpg >>> import asyncio >>> import ssl >>> async def main(): ... sslctx = ssl.create_default_context( ... ssl.Purpose.SERVER_AUTH) ... sslctx.check_hostname = False ... sslctx.verify_mode = ssl.CERT_NONE ... con = await asyncpg.connect(user='postgres', ssl=sslctx) ... 
await con.close() >>> asyncio.run(main()) :param bool direct_tls: Pass ``True`` to skip PostgreSQL STARTTLS mode and perform a direct SSL connection. Must be used alongside ``ssl`` param. :param dict server_settings: An optional dict of server runtime parameters. Refer to PostgreSQL documentation for a `list of supported options `_. :param type connection_class: Class of the returned connection object. Must be a subclass of :class:`~asyncpg.connection.Connection`. :param type record_class: If specified, the class to use for records returned by queries on this connection object. Must be a subclass of :class:`~asyncpg.Record`. :param SessionAttribute target_session_attrs: If specified, check that the host has the correct attribute. Can be one of: - ``"any"`` - the first successfully connected host - ``"primary"`` - the host must NOT be in hot standby mode - ``"standby"`` - the host must be in hot standby mode - ``"read-write"`` - the host must allow writes - ``"read-only"`` - the host must NOT allow writes - ``"prefer-standby"`` - first try to find a standby host, but if none of the listed hosts is a standby server, return any of them. If not specified, the value parsed from the *dsn* argument is used, or the value of the ``PGTARGETSESSIONATTRS`` environment variable, or ``"any"`` if neither is specified. :param str krbsrvname: Kerberos service name to use when authenticating with GSSAPI. This must match the server configuration. Defaults to 'postgres'. :param str gsslib: GSS library to use for GSSAPI/SSPI authentication. Can be 'gssapi' or 'sspi'. Defaults to 'sspi' on Windows and 'gssapi' otherwise. :return: A :class:`~asyncpg.connection.Connection` instance. Example: .. code-block:: pycon >>> import asyncpg >>> import asyncio >>> async def run(): ... con = await asyncpg.connect(user='postgres') ... types = await con.fetch('SELECT * FROM pg_type') ... print(types) ...
def get_max_lifetime(self):
    """Return the lifetime value currently stored on the cache."""
    return self._max_lifetime


def set_max_lifetime(self, new_lifetime):
    """Store *new_lifetime* and refresh the timeout of every cached entry.

    :param new_lifetime: Non-negative lifetime; a falsy value means that
        no expiry callback is scheduled by ``_set_entry_timeout()``.
    """
    assert new_lifetime >= 0
    self._max_lifetime = new_lifetime
    # ``_set_entry_timeout()`` drops the entry's previous callback and
    # schedules a replacement when the lifetime is non-zero.
    for cached in self._entries.values():
        self._set_entry_timeout(cached)


def get(self, query, *, promote=True):
    """Return the cached statement for *query*, or ``None``.

    ``None`` is returned when the cache is disabled (``_max_size`` is
    falsy), when *query* is unknown, or when the cached statement turned
    out to be closed -- in the last case the stale entry is evicted.

    :param query: Cache key.
    :param bool promote: Move the entry to the most-recently-used end of
        the ordering.  ``has()`` passes ``False`` so that a mere
        membership test does not reorder the cache.
    """
    cached = self._entries.get(query) if self._max_size else None
    if cached is None:
        return None

    stmt = cached._statement
    if stmt.closed:
        # Happens in unittests when `stmt._state.mark_closed()` is called
        # manually, or when a prepared statement closes itself on a type
        # cache error.
        del self._entries[query]
        self._clear_entry_callback(cached)
        return None

    if promote:
        self._entries.move_to_end(query, last=True)
    return stmt


def has(self, query):
    """Tell whether a live statement is cached for *query*."""
    found = self.get(query, promote=False)
    return found is not None


def put(self, query, statement):
    """Cache *statement* under *query*; a no-op when the cache is disabled."""
    if self._max_size:
        self._entries[query] = self._new_entry(query, statement)
        # The cache may now exceed max_size; let the cleanup trim it.
        self._maybe_cleanup()


def iter_statements(self):
    """Return a lazy iterator over the cached statements, in cache order."""
    return (cached._statement for cached in self._entries.values())


def clear(self):
    """Empty the cache, notifying ``_on_remove`` for every statement."""
    # Snapshot first: the mapping is emptied before callbacks run.
    evicted = tuple(self._entries.values())
    self._entries.clear()

    # Cancel every scheduled callback and report each removal.
    for cached in evicted:
        self._clear_entry_callback(cached)
        self._on_remove(cached._statement)
class _Callback(typing.NamedTuple):
    """A callable paired with a flag telling whether it is a coroutine
    function."""

    # The callable exactly as supplied by the caller.
    cb: typing.Callable[..., None]
    # True when ``cb`` is a coroutine function per
    # ``inspect.iscoroutinefunction()``.
    is_async: bool

    @classmethod
    def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback':
        """Classify *cb* as async or sync and wrap it in a ``_Callback``.

        :param cb: A plain callable or an ``async def`` function.
        :return: ``_Callback(cb, is_async)``.
        :raises InterfaceError: If *cb* is not callable at all.
        """
        if inspect.iscoroutinefunction(cb):
            is_async = True
        elif callable(cb):
            is_async = False
        else:
            # The two literals are concatenated by the compiler; the
            # trailing space keeps "function," and "got" apart.
            raise exceptions.InterfaceError(
                'expected a callable or an `async def` function, '
                'got {!r}'.format(cb)
            )

        return cls(cb, is_async)


class _Atomic:
    """Non-reentrant guard usable as a ``with`` block.

    Entering while the guard is already held raises ``InterfaceError``;
    leaving the block always releases it.
    """

    __slots__ = ('_acquired',)

    def __init__(self):
        self._acquired = 0

    def __enter__(self):
        if self._acquired:
            raise exceptions.InterfaceError(
                'cannot perform operation: another operation is in progress')
        self._acquired = 1

    def __exit__(self, t, e, tb):
        # Nothing is returned, so exceptions raised inside the block
        # propagate to the caller.
        self._acquired = 0


class _ConnectionProxy:
    # Base class to enable `isinstance(Connection)` check.
    __slots__ = ()


LoggedQuery = collections.namedtuple(
    'LoggedQuery',
    ['query', 'args', 'timeout', 'elapsed', 'exception', 'conn_addr',
     'conn_params'])
LoggedQuery.__doc__ = 'Log record of an executed query.'
ServerCapabilities = collections.namedtuple(
    'ServerCapabilities',
    ['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
     'sql_close_all', 'sql_copy_from_where', 'jit'])
ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.'


def _detect_server_capabilities(server_version, connection_settings):
    """Build the ``ServerCapabilities`` record for the connected server.

    The server flavour is recognised by the presence of a marker
    attribute on *connection_settings*; when no marker is found a
    standard PostgreSQL server is assumed and *server_version* decides
    the version-dependent capabilities.
    """
    # (marker attribute, sql_reset supported) -- checked in this order.
    # Every other capability is off for these servers.
    flavours = (
        ('padb_revision', True),    # Amazon Redshift
        ('crdb_version', False),    # CockroachDB
        ('crate_version', False),   # CrateDB
    )
    for marker, has_sql_reset in flavours:
        if hasattr(connection_settings, marker):
            return ServerCapabilities(
                advisory_locks=False,
                notifications=False,
                plpgsql=False,
                sql_reset=has_sql_reset,
                sql_close_all=False,
                sql_copy_from_where=False,
                jit=False,
            )

    # Standard PostgreSQL server assumed.
    has_jit = server_version >= (11, 0)
    has_copy_where = server_version.major >= 12
    return ServerCapabilities(
        advisory_locks=True,
        notifications=True,
        plpgsql=True,
        sql_reset=True,
        sql_close_all=True,
        sql_copy_from_where=has_copy_where,
        jit=has_jit,
    )


def _extract_stack(limit=10):
    """Return the caller's formatted stack with leading asyncpg frames
    dropped.

    A lighter stand-in for ``traceback.extract_stack()``: source lines are
    not looked up while the summary is built.  At most *limit* frames are
    kept, and they are formatted outermost-first.
    """
    caller = sys._getframe().f_back
    try:
        summary = traceback.StackSummary.extract(
            traceback.walk_stack(caller), lookup_lines=False)
    finally:
        # Drop the frame reference as soon as it is no longer needed.
        del caller

    pkg_dir = asyncpg.__path__[0]
    # The summary starts at the caller; count how many consecutive
    # frames from there belong to the asyncpg package itself.
    skip = 0
    for frame_summary in summary:
        if not frame_summary.filename.startswith(pkg_dir):
            break
        skip += 1

    frames = summary[skip:skip + limit]
    frames.reverse()
    return ''.join(traceback.format_list(frames))


def _check_record_class(record_class):
    """Validate a user-supplied *record_class*.

    :raises InterfaceError: If *record_class* is not ``protocol.Record``
        or a subclass of it, or if a subclass redefines ``__new__`` or
        ``__init__``.
    """
    base = protocol.Record
    if record_class is base:
        return

    if not (isinstance(record_class, type) and
            issubclass(record_class, base)):
        raise exceptions.InterfaceError(
            'record_class is expected to be a subclass of '
            'asyncpg.Record, got {!r}'.format(record_class)
        )

    if (
        record_class.__new__ is not base.__new__
        or record_class.__init__ is not base.__init__
    ):
        raise exceptions.InterfaceError(
            'record_class must not redefine __new__ or __init__'
        )


def _weak_maybe_gc_stmt(weak_ref, stmt):
    """Forward *stmt* to ``_maybe_gc_stmt()`` of the object behind
    *weak_ref*, doing nothing when the reference is already dead."""
    owner = weak_ref()
    if owner is None:
        return
    owner._maybe_gc_stmt(stmt)


_uid = 0
def guarded(meth):
    """Decorate a ConnectionResource method with a connection sanity check.

    The wrapper calls ``self._check_conn_validity()`` with the method's
    name before delegating to *meth*, so an invalid connection is
    reported before the method body runs.
    """
    def _guarded_call(self, *args, **kwargs):
        self._check_conn_validity(meth.__name__)
        return meth(self, *args, **kwargs)

    return functools.wraps(meth)(_guarded_call)


class ConnectionResource:
    """Base for objects tied to a connection.

    The connection's ``_pool_release_ctr`` is recorded at construction
    time; a later mismatch means the connection has been released back to
    the pool since this resource was created.
    """

    __slots__ = ('_connection', '_con_release_ctr')

    def __init__(self, connection):
        self._connection = connection
        self._con_release_ctr = connection._pool_release_ctr

    def _check_conn_validity(self, meth_name):
        """Raise ``InterfaceError`` if the connection is no longer usable.

        :param str meth_name: Name of the method being guarded; used only
            to build the error message.
        """
        conn = self._connection
        if conn._pool_release_ctr != self._con_release_ctr:
            template = (
                'cannot call {}.{}(): '
                'the underlying connection has been released back '
                'to the pool')
        elif conn.is_closed():
            template = (
                'cannot call {}.{}(): '
                'the underlying connection is closed')
        else:
            return

        raise exceptions.InterfaceError(
            template.format(self.__class__.__name__, meth_name))
""" __slots__ = ( '_state', '_args', '_prefetch', '_query', '_timeout', '_record_class', ) def __init__( self, connection, query, state, args, prefetch, timeout, record_class ): super().__init__(connection) self._args = args self._prefetch = prefetch self._query = query self._timeout = timeout self._state = state self._record_class = record_class if state is not None: state.attach() @connresource.guarded def __aiter__(self): prefetch = 50 if self._prefetch is None else self._prefetch return CursorIterator( self._connection, self._query, self._state, self._args, self._record_class, prefetch, self._timeout, ) @connresource.guarded def __await__(self): if self._prefetch is not None: raise exceptions.InterfaceError( 'prefetch argument can only be specified for iterable cursor') cursor = Cursor( self._connection, self._query, self._state, self._args, self._record_class, ) return cursor._init(self._timeout).__await__() def __del__(self): if self._state is not None: self._state.detach() self._connection._maybe_gc_stmt(self._state) class BaseCursor(connresource.ConnectionResource): __slots__ = ( '_state', '_args', '_portal_name', '_exhausted', '_query', '_record_class', ) def __init__(self, connection, query, state, args, record_class): super().__init__(connection) self._args = args self._state = state if state is not None: state.attach() self._portal_name = None self._exhausted = False self._query = query self._record_class = record_class def _check_ready(self): if self._state is None: raise exceptions.InterfaceError( 'cursor: no associated prepared statement') if self._state.closed: raise exceptions.InterfaceError( 'cursor: the prepared statement is closed') if not self._connection._top_xact: raise exceptions.NoActiveSQLTransactionError( 'cursor cannot be created outside of a transaction') async def _bind_exec(self, n, timeout): self._check_ready() if self._portal_name: raise exceptions.InterfaceError( 'cursor already has an open portal') con = self._connection protocol 
= con._protocol self._portal_name = con._get_unique_id('portal') buffer, _, self._exhausted = await protocol.bind_execute( self._state, self._args, self._portal_name, n, True, timeout) return buffer async def _bind(self, timeout): self._check_ready() if self._portal_name: raise exceptions.InterfaceError( 'cursor already has an open portal') con = self._connection protocol = con._protocol self._portal_name = con._get_unique_id('portal') buffer = await protocol.bind(self._state, self._args, self._portal_name, timeout) return buffer async def _exec(self, n, timeout): self._check_ready() if not self._portal_name: raise exceptions.InterfaceError( 'cursor does not have an open portal') protocol = self._connection._protocol buffer, _, self._exhausted = await protocol.execute( self._state, self._portal_name, n, True, timeout) return buffer async def _close_portal(self, timeout): self._check_ready() if not self._portal_name: raise exceptions.InterfaceError( 'cursor does not have an open portal') protocol = self._connection._protocol await protocol.close_portal(self._portal_name, timeout) self._portal_name = None def __repr__(self): attrs = [] if self._exhausted: attrs.append('exhausted') attrs.append('') # to separate from id if self.__class__.__module__.startswith('asyncpg.'): mod = 'asyncpg' else: mod = self.__class__.__module__ return '<{}.{} "{!s:.30}" {}{:#x}>'.format( mod, self.__class__.__name__, self._state.query, ' '.join(attrs), id(self)) def __del__(self): if self._state is not None: self._state.detach() self._connection._maybe_gc_stmt(self._state) class CursorIterator(BaseCursor): __slots__ = ('_buffer', '_prefetch', '_timeout') def __init__( self, connection, query, state, args, record_class, prefetch, timeout ): super().__init__(connection, query, state, args, record_class) if prefetch <= 0: raise exceptions.InterfaceError( 'prefetch argument must be greater than zero') self._buffer = collections.deque() self._prefetch = prefetch self._timeout = timeout 
@connresource.guarded def __aiter__(self): return self @connresource.guarded async def __anext__(self): if self._state is None: self._state = await self._connection._get_statement( self._query, self._timeout, named=True, record_class=self._record_class, ) self._state.attach() if not self._portal_name and not self._exhausted: buffer = await self._bind_exec(self._prefetch, self._timeout) self._buffer.extend(buffer) if not self._buffer and not self._exhausted: buffer = await self._exec(self._prefetch, self._timeout) self._buffer.extend(buffer) if self._portal_name and self._exhausted: await self._close_portal(self._timeout) if self._buffer: return self._buffer.popleft() raise StopAsyncIteration class Cursor(BaseCursor): """An open *portal* into the results of a query.""" __slots__ = () async def _init(self, timeout): if self._state is None: self._state = await self._connection._get_statement( self._query, timeout, named=True, record_class=self._record_class, ) self._state.attach() self._check_ready() await self._bind(timeout) return self @connresource.guarded async def fetch(self, n, *, timeout=None): r"""Return the next *n* rows as a list of :class:`Record` objects. :param float timeout: Optional timeout value in seconds. :return: A list of :class:`Record` instances. """ self._check_ready() if n <= 0: raise exceptions.InterfaceError('n must be greater than zero') if self._exhausted: return [] recs = await self._exec(n, timeout) if len(recs) < n: self._exhausted = True return recs @connresource.guarded async def fetchrow(self, *, timeout=None): r"""Return the next row. :param float timeout: Optional timeout value in seconds. :return: A :class:`Record` instance. """ self._check_ready() if self._exhausted: return None recs = await self._exec(1, timeout) if len(recs) < 1: self._exhausted = True return None return recs[0] @connresource.guarded async def forward(self, n, *, timeout=None) -> int: r"""Skip over the next *n* rows. 
:param float timeout: Optional timeout value in seconds. :return: A number of rows actually skipped over (<= *n*). """ self._check_ready() if n <= 0: raise exceptions.InterfaceError('n must be greater than zero') protocol = self._connection._protocol status = await protocol.query('MOVE FORWARD {:d} {}'.format( n, self._portal_name), timeout) advanced = int(status.split()[1]) if advanced < n: self._exhausted = True return advanced ================================================ FILE: asyncpg/exceptions/__init__.py ================================================ # GENERATED FROM postgresql/src/backend/utils/errcodes.txt # DO NOT MODIFY, use tools/generate_exceptions.py to update from ._base import * # NOQA from . import _base class PostgresWarning(_base.PostgresLogMessage, Warning): sqlstate = '01000' class DynamicResultSetsReturned(PostgresWarning): sqlstate = '0100C' class ImplicitZeroBitPadding(PostgresWarning): sqlstate = '01008' class NullValueEliminatedInSetFunction(PostgresWarning): sqlstate = '01003' class PrivilegeNotGranted(PostgresWarning): sqlstate = '01007' class PrivilegeNotRevoked(PostgresWarning): sqlstate = '01006' class StringDataRightTruncation(PostgresWarning): sqlstate = '01004' class DeprecatedFeature(PostgresWarning): sqlstate = '01P01' class NoData(PostgresWarning): sqlstate = '02000' class NoAdditionalDynamicResultSetsReturned(NoData): sqlstate = '02001' class SQLStatementNotYetCompleteError(_base.PostgresError): sqlstate = '03000' class PostgresConnectionError(_base.PostgresError): sqlstate = '08000' class ConnectionDoesNotExistError(PostgresConnectionError): sqlstate = '08003' class ConnectionFailureError(PostgresConnectionError): sqlstate = '08006' class ClientCannotConnectError(PostgresConnectionError): sqlstate = '08001' class ConnectionRejectionError(PostgresConnectionError): sqlstate = '08004' class TransactionResolutionUnknownError(PostgresConnectionError): sqlstate = '08007' class ProtocolViolationError(PostgresConnectionError): 
sqlstate = '08P01' class TriggeredActionError(_base.PostgresError): sqlstate = '09000' class FeatureNotSupportedError(_base.PostgresError): sqlstate = '0A000' class InvalidCachedStatementError(FeatureNotSupportedError): pass class InvalidTransactionInitiationError(_base.PostgresError): sqlstate = '0B000' class LocatorError(_base.PostgresError): sqlstate = '0F000' class InvalidLocatorSpecificationError(LocatorError): sqlstate = '0F001' class InvalidGrantorError(_base.PostgresError): sqlstate = '0L000' class InvalidGrantOperationError(InvalidGrantorError): sqlstate = '0LP01' class InvalidRoleSpecificationError(_base.PostgresError): sqlstate = '0P000' class DiagnosticsError(_base.PostgresError): sqlstate = '0Z000' class StackedDiagnosticsAccessedWithoutActiveHandlerError(DiagnosticsError): sqlstate = '0Z002' class InvalidArgumentForXqueryError(_base.PostgresError): sqlstate = '10608' class CaseNotFoundError(_base.PostgresError): sqlstate = '20000' class CardinalityViolationError(_base.PostgresError): sqlstate = '21000' class DataError(_base.PostgresError): sqlstate = '22000' class ArraySubscriptError(DataError): sqlstate = '2202E' class CharacterNotInRepertoireError(DataError): sqlstate = '22021' class DatetimeFieldOverflowError(DataError): sqlstate = '22008' class DivisionByZeroError(DataError): sqlstate = '22012' class ErrorInAssignmentError(DataError): sqlstate = '22005' class EscapeCharacterConflictError(DataError): sqlstate = '2200B' class IndicatorOverflowError(DataError): sqlstate = '22022' class IntervalFieldOverflowError(DataError): sqlstate = '22015' class InvalidArgumentForLogarithmError(DataError): sqlstate = '2201E' class InvalidArgumentForNtileFunctionError(DataError): sqlstate = '22014' class InvalidArgumentForNthValueFunctionError(DataError): sqlstate = '22016' class InvalidArgumentForPowerFunctionError(DataError): sqlstate = '2201F' class InvalidArgumentForWidthBucketFunctionError(DataError): sqlstate = '2201G' class 
InvalidCharacterValueForCastError(DataError): sqlstate = '22018' class InvalidDatetimeFormatError(DataError): sqlstate = '22007' class InvalidEscapeCharacterError(DataError): sqlstate = '22019' class InvalidEscapeOctetError(DataError): sqlstate = '2200D' class InvalidEscapeSequenceError(DataError): sqlstate = '22025' class NonstandardUseOfEscapeCharacterError(DataError): sqlstate = '22P06' class InvalidIndicatorParameterValueError(DataError): sqlstate = '22010' class InvalidParameterValueError(DataError): sqlstate = '22023' class InvalidPrecedingOrFollowingSizeError(DataError): sqlstate = '22013' class InvalidRegularExpressionError(DataError): sqlstate = '2201B' class InvalidRowCountInLimitClauseError(DataError): sqlstate = '2201W' class InvalidRowCountInResultOffsetClauseError(DataError): sqlstate = '2201X' class InvalidTablesampleArgumentError(DataError): sqlstate = '2202H' class InvalidTablesampleRepeatError(DataError): sqlstate = '2202G' class InvalidTimeZoneDisplacementValueError(DataError): sqlstate = '22009' class InvalidUseOfEscapeCharacterError(DataError): sqlstate = '2200C' class MostSpecificTypeMismatchError(DataError): sqlstate = '2200G' class NullValueNotAllowedError(DataError): sqlstate = '22004' class NullValueNoIndicatorParameterError(DataError): sqlstate = '22002' class NumericValueOutOfRangeError(DataError): sqlstate = '22003' class SequenceGeneratorLimitExceededError(DataError): sqlstate = '2200H' class StringDataLengthMismatchError(DataError): sqlstate = '22026' class StringDataRightTruncationError(DataError): sqlstate = '22001' class SubstringError(DataError): sqlstate = '22011' class TrimError(DataError): sqlstate = '22027' class UnterminatedCStringError(DataError): sqlstate = '22024' class ZeroLengthCharacterStringError(DataError): sqlstate = '2200F' class PostgresFloatingPointError(DataError): sqlstate = '22P01' class InvalidTextRepresentationError(DataError): sqlstate = '22P02' class InvalidBinaryRepresentationError(DataError): sqlstate = 
'22P03' class BadCopyFileFormatError(DataError): sqlstate = '22P04' class UntranslatableCharacterError(DataError): sqlstate = '22P05' class NotAnXmlDocumentError(DataError): sqlstate = '2200L' class InvalidXmlDocumentError(DataError): sqlstate = '2200M' class InvalidXmlContentError(DataError): sqlstate = '2200N' class InvalidXmlCommentError(DataError): sqlstate = '2200S' class InvalidXmlProcessingInstructionError(DataError): sqlstate = '2200T' class DuplicateJsonObjectKeyValueError(DataError): sqlstate = '22030' class InvalidArgumentForSQLJsonDatetimeFunctionError(DataError): sqlstate = '22031' class InvalidJsonTextError(DataError): sqlstate = '22032' class InvalidSQLJsonSubscriptError(DataError): sqlstate = '22033' class MoreThanOneSQLJsonItemError(DataError): sqlstate = '22034' class NoSQLJsonItemError(DataError): sqlstate = '22035' class NonNumericSQLJsonItemError(DataError): sqlstate = '22036' class NonUniqueKeysInAJsonObjectError(DataError): sqlstate = '22037' class SingletonSQLJsonItemRequiredError(DataError): sqlstate = '22038' class SQLJsonArrayNotFoundError(DataError): sqlstate = '22039' class SQLJsonMemberNotFoundError(DataError): sqlstate = '2203A' class SQLJsonNumberNotFoundError(DataError): sqlstate = '2203B' class SQLJsonObjectNotFoundError(DataError): sqlstate = '2203C' class TooManyJsonArrayElementsError(DataError): sqlstate = '2203D' class TooManyJsonObjectMembersError(DataError): sqlstate = '2203E' class SQLJsonScalarRequiredError(DataError): sqlstate = '2203F' class SQLJsonItemCannotBeCastToTargetTypeError(DataError): sqlstate = '2203G' class IntegrityConstraintViolationError(_base.PostgresError): sqlstate = '23000' class RestrictViolationError(IntegrityConstraintViolationError): sqlstate = '23001' class NotNullViolationError(IntegrityConstraintViolationError): sqlstate = '23502' class ForeignKeyViolationError(IntegrityConstraintViolationError): sqlstate = '23503' class UniqueViolationError(IntegrityConstraintViolationError): sqlstate = '23505' 
class CheckViolationError(IntegrityConstraintViolationError): sqlstate = '23514' class ExclusionViolationError(IntegrityConstraintViolationError): sqlstate = '23P01' class InvalidCursorStateError(_base.PostgresError): sqlstate = '24000' class InvalidTransactionStateError(_base.PostgresError): sqlstate = '25000' class ActiveSQLTransactionError(InvalidTransactionStateError): sqlstate = '25001' class BranchTransactionAlreadyActiveError(InvalidTransactionStateError): sqlstate = '25002' class HeldCursorRequiresSameIsolationLevelError(InvalidTransactionStateError): sqlstate = '25008' class InappropriateAccessModeForBranchTransactionError( InvalidTransactionStateError): sqlstate = '25003' class InappropriateIsolationLevelForBranchTransactionError( InvalidTransactionStateError): sqlstate = '25004' class NoActiveSQLTransactionForBranchTransactionError( InvalidTransactionStateError): sqlstate = '25005' class ReadOnlySQLTransactionError(InvalidTransactionStateError): sqlstate = '25006' class SchemaAndDataStatementMixingNotSupportedError( InvalidTransactionStateError): sqlstate = '25007' class NoActiveSQLTransactionError(InvalidTransactionStateError): sqlstate = '25P01' class InFailedSQLTransactionError(InvalidTransactionStateError): sqlstate = '25P02' class IdleInTransactionSessionTimeoutError(InvalidTransactionStateError): sqlstate = '25P03' class TransactionTimeoutError(InvalidTransactionStateError): sqlstate = '25P04' class InvalidSQLStatementNameError(_base.PostgresError): sqlstate = '26000' class TriggeredDataChangeViolationError(_base.PostgresError): sqlstate = '27000' class InvalidAuthorizationSpecificationError(_base.PostgresError): sqlstate = '28000' class InvalidPasswordError(InvalidAuthorizationSpecificationError): sqlstate = '28P01' class DependentPrivilegeDescriptorsStillExistError(_base.PostgresError): sqlstate = '2B000' class DependentObjectsStillExistError( DependentPrivilegeDescriptorsStillExistError): sqlstate = '2BP01' class 
InvalidTransactionTerminationError(_base.PostgresError): sqlstate = '2D000' class SQLRoutineError(_base.PostgresError): sqlstate = '2F000' class FunctionExecutedNoReturnStatementError(SQLRoutineError): sqlstate = '2F005' class ModifyingSQLDataNotPermittedError(SQLRoutineError): sqlstate = '2F002' class ProhibitedSQLStatementAttemptedError(SQLRoutineError): sqlstate = '2F003' class ReadingSQLDataNotPermittedError(SQLRoutineError): sqlstate = '2F004' class InvalidCursorNameError(_base.PostgresError): sqlstate = '34000' class ExternalRoutineError(_base.PostgresError): sqlstate = '38000' class ContainingSQLNotPermittedError(ExternalRoutineError): sqlstate = '38001' class ModifyingExternalRoutineSQLDataNotPermittedError(ExternalRoutineError): sqlstate = '38002' class ProhibitedExternalRoutineSQLStatementAttemptedError( ExternalRoutineError): sqlstate = '38003' class ReadingExternalRoutineSQLDataNotPermittedError(ExternalRoutineError): sqlstate = '38004' class ExternalRoutineInvocationError(_base.PostgresError): sqlstate = '39000' class InvalidSqlstateReturnedError(ExternalRoutineInvocationError): sqlstate = '39001' class NullValueInExternalRoutineNotAllowedError( ExternalRoutineInvocationError): sqlstate = '39004' class TriggerProtocolViolatedError(ExternalRoutineInvocationError): sqlstate = '39P01' class SrfProtocolViolatedError(ExternalRoutineInvocationError): sqlstate = '39P02' class EventTriggerProtocolViolatedError(ExternalRoutineInvocationError): sqlstate = '39P03' class SavepointError(_base.PostgresError): sqlstate = '3B000' class InvalidSavepointSpecificationError(SavepointError): sqlstate = '3B001' class InvalidCatalogNameError(_base.PostgresError): sqlstate = '3D000' class InvalidSchemaNameError(_base.PostgresError): sqlstate = '3F000' class TransactionRollbackError(_base.PostgresError): sqlstate = '40000' class TransactionIntegrityConstraintViolationError(TransactionRollbackError): sqlstate = '40002' class SerializationError(TransactionRollbackError): 
sqlstate = '40001' class StatementCompletionUnknownError(TransactionRollbackError): sqlstate = '40003' class DeadlockDetectedError(TransactionRollbackError): sqlstate = '40P01' class SyntaxOrAccessError(_base.PostgresError): sqlstate = '42000' class PostgresSyntaxError(SyntaxOrAccessError): sqlstate = '42601' class InsufficientPrivilegeError(SyntaxOrAccessError): sqlstate = '42501' class CannotCoerceError(SyntaxOrAccessError): sqlstate = '42846' class GroupingError(SyntaxOrAccessError): sqlstate = '42803' class WindowingError(SyntaxOrAccessError): sqlstate = '42P20' class InvalidRecursionError(SyntaxOrAccessError): sqlstate = '42P19' class InvalidForeignKeyError(SyntaxOrAccessError): sqlstate = '42830' class InvalidNameError(SyntaxOrAccessError): sqlstate = '42602' class NameTooLongError(SyntaxOrAccessError): sqlstate = '42622' class ReservedNameError(SyntaxOrAccessError): sqlstate = '42939' class DatatypeMismatchError(SyntaxOrAccessError): sqlstate = '42804' class IndeterminateDatatypeError(SyntaxOrAccessError): sqlstate = '42P18' class CollationMismatchError(SyntaxOrAccessError): sqlstate = '42P21' class IndeterminateCollationError(SyntaxOrAccessError): sqlstate = '42P22' class WrongObjectTypeError(SyntaxOrAccessError): sqlstate = '42809' class GeneratedAlwaysError(SyntaxOrAccessError): sqlstate = '428C9' class UndefinedColumnError(SyntaxOrAccessError): sqlstate = '42703' class UndefinedFunctionError(SyntaxOrAccessError): sqlstate = '42883' class UndefinedTableError(SyntaxOrAccessError): sqlstate = '42P01' class UndefinedParameterError(SyntaxOrAccessError): sqlstate = '42P02' class UndefinedObjectError(SyntaxOrAccessError): sqlstate = '42704' class DuplicateColumnError(SyntaxOrAccessError): sqlstate = '42701' class DuplicateCursorError(SyntaxOrAccessError): sqlstate = '42P03' class DuplicateDatabaseError(SyntaxOrAccessError): sqlstate = '42P04' class DuplicateFunctionError(SyntaxOrAccessError): sqlstate = '42723' class 
DuplicatePreparedStatementError(SyntaxOrAccessError): sqlstate = '42P05' class DuplicateSchemaError(SyntaxOrAccessError): sqlstate = '42P06' class DuplicateTableError(SyntaxOrAccessError): sqlstate = '42P07' class DuplicateAliasError(SyntaxOrAccessError): sqlstate = '42712' class DuplicateObjectError(SyntaxOrAccessError): sqlstate = '42710' class AmbiguousColumnError(SyntaxOrAccessError): sqlstate = '42702' class AmbiguousFunctionError(SyntaxOrAccessError): sqlstate = '42725' class AmbiguousParameterError(SyntaxOrAccessError): sqlstate = '42P08' class AmbiguousAliasError(SyntaxOrAccessError): sqlstate = '42P09' class InvalidColumnReferenceError(SyntaxOrAccessError): sqlstate = '42P10' class InvalidColumnDefinitionError(SyntaxOrAccessError): sqlstate = '42611' class InvalidCursorDefinitionError(SyntaxOrAccessError): sqlstate = '42P11' class InvalidDatabaseDefinitionError(SyntaxOrAccessError): sqlstate = '42P12' class InvalidFunctionDefinitionError(SyntaxOrAccessError): sqlstate = '42P13' class InvalidPreparedStatementDefinitionError(SyntaxOrAccessError): sqlstate = '42P14' class InvalidSchemaDefinitionError(SyntaxOrAccessError): sqlstate = '42P15' class InvalidTableDefinitionError(SyntaxOrAccessError): sqlstate = '42P16' class InvalidObjectDefinitionError(SyntaxOrAccessError): sqlstate = '42P17' class WithCheckOptionViolationError(_base.PostgresError): sqlstate = '44000' class InsufficientResourcesError(_base.PostgresError): sqlstate = '53000' class DiskFullError(InsufficientResourcesError): sqlstate = '53100' class OutOfMemoryError(InsufficientResourcesError): sqlstate = '53200' class TooManyConnectionsError(InsufficientResourcesError): sqlstate = '53300' class ConfigurationLimitExceededError(InsufficientResourcesError): sqlstate = '53400' class ProgramLimitExceededError(_base.PostgresError): sqlstate = '54000' class StatementTooComplexError(ProgramLimitExceededError): sqlstate = '54001' class TooManyColumnsError(ProgramLimitExceededError): sqlstate = '54011' class 
TooManyArgumentsError(ProgramLimitExceededError): sqlstate = '54023' class ObjectNotInPrerequisiteStateError(_base.PostgresError): sqlstate = '55000' class ObjectInUseError(ObjectNotInPrerequisiteStateError): sqlstate = '55006' class CantChangeRuntimeParamError(ObjectNotInPrerequisiteStateError): sqlstate = '55P02' class LockNotAvailableError(ObjectNotInPrerequisiteStateError): sqlstate = '55P03' class UnsafeNewEnumValueUsageError(ObjectNotInPrerequisiteStateError): sqlstate = '55P04' class OperatorInterventionError(_base.PostgresError): sqlstate = '57000' class QueryCanceledError(OperatorInterventionError): sqlstate = '57014' class AdminShutdownError(OperatorInterventionError): sqlstate = '57P01' class CrashShutdownError(OperatorInterventionError): sqlstate = '57P02' class CannotConnectNowError(OperatorInterventionError): sqlstate = '57P03' class DatabaseDroppedError(OperatorInterventionError): sqlstate = '57P04' class IdleSessionTimeoutError(OperatorInterventionError): sqlstate = '57P05' class PostgresSystemError(_base.PostgresError): sqlstate = '58000' class PostgresIOError(PostgresSystemError): sqlstate = '58030' class UndefinedFileError(PostgresSystemError): sqlstate = '58P01' class DuplicateFileError(PostgresSystemError): sqlstate = '58P02' class FileNameTooLongError(PostgresSystemError): sqlstate = '58P03' class SnapshotTooOldError(_base.PostgresError): sqlstate = '72000' class ConfigFileError(_base.PostgresError): sqlstate = 'F0000' class LockFileExistsError(ConfigFileError): sqlstate = 'F0001' class FDWError(_base.PostgresError): sqlstate = 'HV000' class FDWColumnNameNotFoundError(FDWError): sqlstate = 'HV005' class FDWDynamicParameterValueNeededError(FDWError): sqlstate = 'HV002' class FDWFunctionSequenceError(FDWError): sqlstate = 'HV010' class FDWInconsistentDescriptorInformationError(FDWError): sqlstate = 'HV021' class FDWInvalidAttributeValueError(FDWError): sqlstate = 'HV024' class FDWInvalidColumnNameError(FDWError): sqlstate = 'HV007' class 
FDWInvalidColumnNumberError(FDWError): sqlstate = 'HV008' class FDWInvalidDataTypeError(FDWError): sqlstate = 'HV004' class FDWInvalidDataTypeDescriptorsError(FDWError): sqlstate = 'HV006' class FDWInvalidDescriptorFieldIdentifierError(FDWError): sqlstate = 'HV091' class FDWInvalidHandleError(FDWError): sqlstate = 'HV00B' class FDWInvalidOptionIndexError(FDWError): sqlstate = 'HV00C' class FDWInvalidOptionNameError(FDWError): sqlstate = 'HV00D' class FDWInvalidStringLengthOrBufferLengthError(FDWError): sqlstate = 'HV090' class FDWInvalidStringFormatError(FDWError): sqlstate = 'HV00A' class FDWInvalidUseOfNullPointerError(FDWError): sqlstate = 'HV009' class FDWTooManyHandlesError(FDWError): sqlstate = 'HV014' class FDWOutOfMemoryError(FDWError): sqlstate = 'HV001' class FDWNoSchemasError(FDWError): sqlstate = 'HV00P' class FDWOptionNameNotFoundError(FDWError): sqlstate = 'HV00J' class FDWReplyHandleError(FDWError): sqlstate = 'HV00K' class FDWSchemaNotFoundError(FDWError): sqlstate = 'HV00Q' class FDWTableNotFoundError(FDWError): sqlstate = 'HV00R' class FDWUnableToCreateExecutionError(FDWError): sqlstate = 'HV00L' class FDWUnableToCreateReplyError(FDWError): sqlstate = 'HV00M' class FDWUnableToEstablishConnectionError(FDWError): sqlstate = 'HV00N' class PLPGSQLError(_base.PostgresError): sqlstate = 'P0000' class RaiseError(PLPGSQLError): sqlstate = 'P0001' class NoDataFoundError(PLPGSQLError): sqlstate = 'P0002' class TooManyRowsError(PLPGSQLError): sqlstate = 'P0003' class AssertError(PLPGSQLError): sqlstate = 'P0004' class InternalServerError(_base.PostgresError): sqlstate = 'XX000' class DataCorruptedError(InternalServerError): sqlstate = 'XX001' class IndexCorruptedError(InternalServerError): sqlstate = 'XX002' __all__ = ( 'ActiveSQLTransactionError', 'AdminShutdownError', 'AmbiguousAliasError', 'AmbiguousColumnError', 'AmbiguousFunctionError', 'AmbiguousParameterError', 'ArraySubscriptError', 'AssertError', 'BadCopyFileFormatError', 
'BranchTransactionAlreadyActiveError', 'CannotCoerceError', 'CannotConnectNowError', 'CantChangeRuntimeParamError', 'CardinalityViolationError', 'CaseNotFoundError', 'CharacterNotInRepertoireError', 'CheckViolationError', 'ClientCannotConnectError', 'CollationMismatchError', 'ConfigFileError', 'ConfigurationLimitExceededError', 'ConnectionDoesNotExistError', 'ConnectionFailureError', 'ConnectionRejectionError', 'ContainingSQLNotPermittedError', 'CrashShutdownError', 'DataCorruptedError', 'DataError', 'DatabaseDroppedError', 'DatatypeMismatchError', 'DatetimeFieldOverflowError', 'DeadlockDetectedError', 'DependentObjectsStillExistError', 'DependentPrivilegeDescriptorsStillExistError', 'DeprecatedFeature', 'DiagnosticsError', 'DiskFullError', 'DivisionByZeroError', 'DuplicateAliasError', 'DuplicateColumnError', 'DuplicateCursorError', 'DuplicateDatabaseError', 'DuplicateFileError', 'DuplicateFunctionError', 'DuplicateJsonObjectKeyValueError', 'DuplicateObjectError', 'DuplicatePreparedStatementError', 'DuplicateSchemaError', 'DuplicateTableError', 'DynamicResultSetsReturned', 'ErrorInAssignmentError', 'EscapeCharacterConflictError', 'EventTriggerProtocolViolatedError', 'ExclusionViolationError', 'ExternalRoutineError', 'ExternalRoutineInvocationError', 'FDWColumnNameNotFoundError', 'FDWDynamicParameterValueNeededError', 'FDWError', 'FDWFunctionSequenceError', 'FDWInconsistentDescriptorInformationError', 'FDWInvalidAttributeValueError', 'FDWInvalidColumnNameError', 'FDWInvalidColumnNumberError', 'FDWInvalidDataTypeDescriptorsError', 'FDWInvalidDataTypeError', 'FDWInvalidDescriptorFieldIdentifierError', 'FDWInvalidHandleError', 'FDWInvalidOptionIndexError', 'FDWInvalidOptionNameError', 'FDWInvalidStringFormatError', 'FDWInvalidStringLengthOrBufferLengthError', 'FDWInvalidUseOfNullPointerError', 'FDWNoSchemasError', 'FDWOptionNameNotFoundError', 'FDWOutOfMemoryError', 'FDWReplyHandleError', 'FDWSchemaNotFoundError', 'FDWTableNotFoundError', 'FDWTooManyHandlesError', 
'FDWUnableToCreateExecutionError', 'FDWUnableToCreateReplyError', 'FDWUnableToEstablishConnectionError', 'FeatureNotSupportedError', 'FileNameTooLongError', 'ForeignKeyViolationError', 'FunctionExecutedNoReturnStatementError', 'GeneratedAlwaysError', 'GroupingError', 'HeldCursorRequiresSameIsolationLevelError', 'IdleInTransactionSessionTimeoutError', 'IdleSessionTimeoutError', 'ImplicitZeroBitPadding', 'InFailedSQLTransactionError', 'InappropriateAccessModeForBranchTransactionError', 'InappropriateIsolationLevelForBranchTransactionError', 'IndeterminateCollationError', 'IndeterminateDatatypeError', 'IndexCorruptedError', 'IndicatorOverflowError', 'InsufficientPrivilegeError', 'InsufficientResourcesError', 'IntegrityConstraintViolationError', 'InternalServerError', 'IntervalFieldOverflowError', 'InvalidArgumentForLogarithmError', 'InvalidArgumentForNthValueFunctionError', 'InvalidArgumentForNtileFunctionError', 'InvalidArgumentForPowerFunctionError', 'InvalidArgumentForSQLJsonDatetimeFunctionError', 'InvalidArgumentForWidthBucketFunctionError', 'InvalidArgumentForXqueryError', 'InvalidAuthorizationSpecificationError', 'InvalidBinaryRepresentationError', 'InvalidCachedStatementError', 'InvalidCatalogNameError', 'InvalidCharacterValueForCastError', 'InvalidColumnDefinitionError', 'InvalidColumnReferenceError', 'InvalidCursorDefinitionError', 'InvalidCursorNameError', 'InvalidCursorStateError', 'InvalidDatabaseDefinitionError', 'InvalidDatetimeFormatError', 'InvalidEscapeCharacterError', 'InvalidEscapeOctetError', 'InvalidEscapeSequenceError', 'InvalidForeignKeyError', 'InvalidFunctionDefinitionError', 'InvalidGrantOperationError', 'InvalidGrantorError', 'InvalidIndicatorParameterValueError', 'InvalidJsonTextError', 'InvalidLocatorSpecificationError', 'InvalidNameError', 'InvalidObjectDefinitionError', 'InvalidParameterValueError', 'InvalidPasswordError', 'InvalidPrecedingOrFollowingSizeError', 'InvalidPreparedStatementDefinitionError', 'InvalidRecursionError', 
'InvalidRegularExpressionError', 'InvalidRoleSpecificationError', 'InvalidRowCountInLimitClauseError', 'InvalidRowCountInResultOffsetClauseError', 'InvalidSQLJsonSubscriptError', 'InvalidSQLStatementNameError', 'InvalidSavepointSpecificationError', 'InvalidSchemaDefinitionError', 'InvalidSchemaNameError', 'InvalidSqlstateReturnedError', 'InvalidTableDefinitionError', 'InvalidTablesampleArgumentError', 'InvalidTablesampleRepeatError', 'InvalidTextRepresentationError', 'InvalidTimeZoneDisplacementValueError', 'InvalidTransactionInitiationError', 'InvalidTransactionStateError', 'InvalidTransactionTerminationError', 'InvalidUseOfEscapeCharacterError', 'InvalidXmlCommentError', 'InvalidXmlContentError', 'InvalidXmlDocumentError', 'InvalidXmlProcessingInstructionError', 'LocatorError', 'LockFileExistsError', 'LockNotAvailableError', 'ModifyingExternalRoutineSQLDataNotPermittedError', 'ModifyingSQLDataNotPermittedError', 'MoreThanOneSQLJsonItemError', 'MostSpecificTypeMismatchError', 'NameTooLongError', 'NoActiveSQLTransactionError', 'NoActiveSQLTransactionForBranchTransactionError', 'NoAdditionalDynamicResultSetsReturned', 'NoData', 'NoDataFoundError', 'NoSQLJsonItemError', 'NonNumericSQLJsonItemError', 'NonUniqueKeysInAJsonObjectError', 'NonstandardUseOfEscapeCharacterError', 'NotAnXmlDocumentError', 'NotNullViolationError', 'NullValueEliminatedInSetFunction', 'NullValueInExternalRoutineNotAllowedError', 'NullValueNoIndicatorParameterError', 'NullValueNotAllowedError', 'NumericValueOutOfRangeError', 'ObjectInUseError', 'ObjectNotInPrerequisiteStateError', 'OperatorInterventionError', 'OutOfMemoryError', 'PLPGSQLError', 'PostgresConnectionError', 'PostgresFloatingPointError', 'PostgresIOError', 'PostgresSyntaxError', 'PostgresSystemError', 'PostgresWarning', 'PrivilegeNotGranted', 'PrivilegeNotRevoked', 'ProgramLimitExceededError', 'ProhibitedExternalRoutineSQLStatementAttemptedError', 'ProhibitedSQLStatementAttemptedError', 'ProtocolViolationError', 
'QueryCanceledError', 'RaiseError', 'ReadOnlySQLTransactionError', 'ReadingExternalRoutineSQLDataNotPermittedError', 'ReadingSQLDataNotPermittedError', 'ReservedNameError', 'RestrictViolationError', 'SQLJsonArrayNotFoundError', 'SQLJsonItemCannotBeCastToTargetTypeError', 'SQLJsonMemberNotFoundError', 'SQLJsonNumberNotFoundError', 'SQLJsonObjectNotFoundError', 'SQLJsonScalarRequiredError', 'SQLRoutineError', 'SQLStatementNotYetCompleteError', 'SavepointError', 'SchemaAndDataStatementMixingNotSupportedError', 'SequenceGeneratorLimitExceededError', 'SerializationError', 'SingletonSQLJsonItemRequiredError', 'SnapshotTooOldError', 'SrfProtocolViolatedError', 'StackedDiagnosticsAccessedWithoutActiveHandlerError', 'StatementCompletionUnknownError', 'StatementTooComplexError', 'StringDataLengthMismatchError', 'StringDataRightTruncation', 'StringDataRightTruncationError', 'SubstringError', 'SyntaxOrAccessError', 'TooManyArgumentsError', 'TooManyColumnsError', 'TooManyConnectionsError', 'TooManyJsonArrayElementsError', 'TooManyJsonObjectMembersError', 'TooManyRowsError', 'TransactionIntegrityConstraintViolationError', 'TransactionResolutionUnknownError', 'TransactionRollbackError', 'TransactionTimeoutError', 'TriggerProtocolViolatedError', 'TriggeredActionError', 'TriggeredDataChangeViolationError', 'TrimError', 'UndefinedColumnError', 'UndefinedFileError', 'UndefinedFunctionError', 'UndefinedObjectError', 'UndefinedParameterError', 'UndefinedTableError', 'UniqueViolationError', 'UnsafeNewEnumValueUsageError', 'UnterminatedCStringError', 'UntranslatableCharacterError', 'WindowingError', 'WithCheckOptionViolationError', 'WrongObjectTypeError', 'ZeroLengthCharacterStringError' ) __all__ += _base.__all__ ================================================ FILE: asyncpg/exceptions/_base.py ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 
License: http://www.apache.org/licenses/LICENSE-2.0 import asyncpg import sys import textwrap __all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError', 'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage', 'ClientConfigurationError', 'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError', 'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched', 'UnsupportedServerFeatureError') def _is_asyncpg_class(cls): modname = cls.__module__ return modname == 'asyncpg' or modname.startswith('asyncpg.') class PostgresMessageMeta(type): _message_map = {} _field_map = { 'S': 'severity', 'V': 'severity_en', 'C': 'sqlstate', 'M': 'message', 'D': 'detail', 'H': 'hint', 'P': 'position', 'p': 'internal_position', 'q': 'internal_query', 'W': 'context', 's': 'schema_name', 't': 'table_name', 'c': 'column_name', 'd': 'data_type_name', 'n': 'constraint_name', 'F': 'server_source_filename', 'L': 'server_source_line', 'R': 'server_source_function' } def __new__(mcls, name, bases, dct): cls = super().__new__(mcls, name, bases, dct) if cls.__module__ == mcls.__module__ and name == 'PostgresMessage': for f in mcls._field_map.values(): setattr(cls, f, None) if _is_asyncpg_class(cls): mod = sys.modules[cls.__module__] if hasattr(mod, name): raise RuntimeError('exception class redefinition: {}'.format( name)) code = dct.get('sqlstate') if code is not None: existing = mcls._message_map.get(code) if existing is not None: raise TypeError('{} has duplicate SQLSTATE code, which is' 'already defined by {}'.format( name, existing.__name__)) mcls._message_map[code] = cls return cls @classmethod def get_message_class_for_sqlstate(mcls, code): return mcls._message_map.get(code, UnknownPostgresError) class PostgresMessage(metaclass=PostgresMessageMeta): @classmethod def _get_error_class(cls, fields): sqlstate = fields.get('C') return type(cls).get_message_class_for_sqlstate(sqlstate) @classmethod def _get_error_dict(cls, fields, query): dct = { 'query': query } 
field_map = type(cls)._field_map for k, v in fields.items(): field = field_map.get(k) if field: dct[field] = v return dct @classmethod def _make_constructor(cls, fields, query=None): dct = cls._get_error_dict(fields, query) exccls = cls._get_error_class(fields) message = dct.get('message', '') # PostgreSQL will raise an exception when it detects # that the result type of the query has changed from # when the statement was prepared. # # The original error is somewhat cryptic and unspecific, # so we raise a custom subclass that is easier to handle # and identify. # # Note that we specifically do not rely on the error # message, as it is localizable. is_icse = ( exccls.__name__ == 'FeatureNotSupportedError' and _is_asyncpg_class(exccls) and dct.get('server_source_function') == 'RevalidateCachedQuery' ) if is_icse: exceptions = sys.modules[exccls.__module__] exccls = exceptions.InvalidCachedStatementError message = ('cached statement plan is invalid due to a database ' 'schema or configuration change') is_prepared_stmt_error = ( exccls.__name__ in ('DuplicatePreparedStatementError', 'InvalidSQLStatementNameError') and _is_asyncpg_class(exccls) ) if is_prepared_stmt_error: hint = dct.get('hint', '') hint += textwrap.dedent("""\ NOTE: pgbouncer with pool_mode set to "transaction" or "statement" does not support prepared statements properly. You have two options: * if you are using pgbouncer for connection pooling to a single server, switch to the connection pool functionality provided by asyncpg, it is a much better option for this purpose; * if you have no option of avoiding the use of pgbouncer, then you can set statement_cache_size to 0 when creating the asyncpg connection object. 
""") dct['hint'] = hint return exccls, message, dct def as_dict(self): dct = {} for f in type(self)._field_map.values(): val = getattr(self, f) if val is not None: dct[f] = val return dct class PostgresError(PostgresMessage, Exception): """Base class for all Postgres errors.""" def __str__(self): msg = self.args[0] if self.detail: msg += '\nDETAIL: {}'.format(self.detail) if self.hint: msg += '\nHINT: {}'.format(self.hint) return msg @classmethod def new(cls, fields, query=None): exccls, message, dct = cls._make_constructor(fields, query) ex = exccls(message) ex.__dict__.update(dct) return ex class FatalPostgresError(PostgresError): """A fatal error that should result in server disconnection.""" class UnknownPostgresError(FatalPostgresError): """An error with an unknown SQLSTATE code.""" class InterfaceMessage: def __init__(self, *, detail=None, hint=None): self.detail = detail self.hint = hint def __str__(self): msg = self.args[0] if self.detail: msg += '\nDETAIL: {}'.format(self.detail) if self.hint: msg += '\nHINT: {}'.format(self.hint) return msg class InterfaceError(InterfaceMessage, Exception): """An error caused by improper use of asyncpg API.""" def __init__(self, msg, *, detail=None, hint=None): InterfaceMessage.__init__(self, detail=detail, hint=hint) Exception.__init__(self, msg) def with_msg(self, msg): return type(self)( msg, detail=self.detail, hint=self.hint, ).with_traceback( self.__traceback__ ) class ClientConfigurationError(InterfaceError, ValueError): """An error caused by improper client configuration.""" class DataError(InterfaceError, ValueError): """An error caused by invalid query input.""" class UnsupportedClientFeatureError(InterfaceError): """Requested feature is unsupported by asyncpg.""" class UnsupportedServerFeatureError(InterfaceError): """Requested feature is unsupported by PostgreSQL server.""" class InterfaceWarning(InterfaceMessage, UserWarning): """A warning caused by an improper use of asyncpg API.""" def __init__(self, msg, 
*, detail=None, hint=None): InterfaceMessage.__init__(self, detail=detail, hint=hint) UserWarning.__init__(self, msg) class InternalClientError(Exception): """All unexpected errors not classified otherwise.""" class ProtocolError(InternalClientError): """Unexpected condition in the handling of PostgreSQL protocol input.""" class TargetServerAttributeNotMatched(InternalClientError): """Could not find a host that satisfies the target attribute requirement""" class OutdatedSchemaCacheError(InternalClientError): """A value decoding error caused by a schema change before row fetching.""" def __init__(self, msg, *, schema=None, data_type=None, position=None): super().__init__(msg) self.schema_name = schema self.data_type_name = data_type self.position = position class PostgresLogMessage(PostgresMessage): """A base class for non-error server messages.""" def __str__(self): return '{}: {}'.format(type(self).__name__, self.message) def __setattr__(self, name, val): raise TypeError('instances of {} are immutable'.format( type(self).__name__)) @classmethod def new(cls, fields, query=None): exccls, message_text, dct = cls._make_constructor(fields, query) if exccls is UnknownPostgresError: exccls = PostgresLogMessage if exccls is PostgresLogMessage: severity = dct.get('severity_en') or dct.get('severity') if severity and severity.upper() == 'WARNING': exccls = asyncpg.PostgresWarning if issubclass(exccls, (BaseException, Warning)): msg = exccls(message_text) else: msg = exccls() msg.__dict__.update(dct) return msg ================================================ FILE: asyncpg/introspection.py ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 from __future__ import annotations import typing from .protocol.protocol import _create_record # type: ignore if typing.TYPE_CHECKING: from . 
import protocol _TYPEINFO_13: typing.Final = '''\ ( SELECT t.oid AS oid, ns.nspname AS ns, t.typname AS name, t.typtype AS kind, (CASE WHEN t.typtype = 'd' THEN (WITH RECURSIVE typebases(oid, depth) AS ( SELECT t2.typbasetype AS oid, 0 AS depth FROM pg_type t2 WHERE t2.oid = t.oid UNION ALL SELECT t2.typbasetype AS oid, tb.depth + 1 AS depth FROM pg_type t2, typebases tb WHERE tb.oid = t2.oid AND t2.typbasetype != 0 ) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1) ELSE NULL END) AS basetype, t.typelem AS elemtype, elem_t.typdelim AS elemdelim, range_t.rngsubtype AS range_subtype, (CASE WHEN t.typtype = 'c' THEN (SELECT array_agg(ia.atttypid ORDER BY ia.attnum) FROM pg_attribute ia INNER JOIN pg_class c ON (ia.attrelid = c.oid) WHERE ia.attnum > 0 AND NOT ia.attisdropped AND c.reltype = t.oid) ELSE NULL END) AS attrtypoids, (CASE WHEN t.typtype = 'c' THEN (SELECT array_agg(ia.attname::text ORDER BY ia.attnum) FROM pg_attribute ia INNER JOIN pg_class c ON (ia.attrelid = c.oid) WHERE ia.attnum > 0 AND NOT ia.attisdropped AND c.reltype = t.oid) ELSE NULL END) AS attrnames FROM pg_catalog.pg_type AS t INNER JOIN pg_catalog.pg_namespace ns ON ( ns.oid = t.typnamespace) LEFT JOIN pg_type elem_t ON ( t.typlen = -1 AND t.typelem != 0 AND t.typelem = elem_t.oid ) LEFT JOIN pg_range range_t ON ( t.oid = range_t.rngtypid ) ) ''' INTRO_LOOKUP_TYPES_13 = '''\ WITH RECURSIVE typeinfo_tree( oid, ns, name, kind, basetype, elemtype, elemdelim, range_subtype, attrtypoids, attrnames, depth) AS ( SELECT ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.elemtype, ti.elemdelim, ti.range_subtype, ti.attrtypoids, ti.attrnames, 0 FROM {typeinfo} AS ti WHERE ti.oid = any($1::oid[]) UNION ALL SELECT ti.oid, ti.ns, ti.name, ti.kind, ti.basetype, ti.elemtype, ti.elemdelim, ti.range_subtype, ti.attrtypoids, ti.attrnames, tt.depth + 1 FROM {typeinfo} ti, typeinfo_tree tt WHERE (tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype) OR (tt.attrtypoids IS NOT NULL AND ti.oid = 
any(tt.attrtypoids)) OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype) OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype) ) SELECT DISTINCT *, basetype::regtype::text AS basetype_name, elemtype::regtype::text AS elemtype_name, range_subtype::regtype::text AS range_subtype_name FROM typeinfo_tree ORDER BY depth DESC '''.format(typeinfo=_TYPEINFO_13) _TYPEINFO: typing.Final = '''\ ( SELECT t.oid AS oid, ns.nspname AS ns, t.typname AS name, t.typtype AS kind, (CASE WHEN t.typtype = 'd' THEN (WITH RECURSIVE typebases(oid, depth) AS ( SELECT t2.typbasetype AS oid, 0 AS depth FROM pg_type t2 WHERE t2.oid = t.oid UNION ALL SELECT t2.typbasetype AS oid, tb.depth + 1 AS depth FROM pg_type t2, typebases tb WHERE tb.oid = t2.oid AND t2.typbasetype != 0 ) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1) ELSE NULL END) AS basetype, t.typelem AS elemtype, elem_t.typdelim AS elemdelim, COALESCE( range_t.rngsubtype, multirange_t.rngsubtype) AS range_subtype, (CASE WHEN t.typtype = 'c' THEN (SELECT array_agg(ia.atttypid ORDER BY ia.attnum) FROM pg_attribute ia INNER JOIN pg_class c ON (ia.attrelid = c.oid) WHERE ia.attnum > 0 AND NOT ia.attisdropped AND c.reltype = t.oid) ELSE NULL END) AS attrtypoids, (CASE WHEN t.typtype = 'c' THEN (SELECT array_agg(ia.attname::text ORDER BY ia.attnum) FROM pg_attribute ia INNER JOIN pg_class c ON (ia.attrelid = c.oid) WHERE ia.attnum > 0 AND NOT ia.attisdropped AND c.reltype = t.oid) ELSE NULL END) AS attrnames FROM pg_catalog.pg_type AS t INNER JOIN pg_catalog.pg_namespace ns ON ( ns.oid = t.typnamespace) LEFT JOIN pg_type elem_t ON ( t.typlen = -1 AND t.typelem != 0 AND t.typelem = elem_t.oid ) LEFT JOIN pg_range range_t ON ( t.oid = range_t.rngtypid ) LEFT JOIN pg_range multirange_t ON ( t.oid = multirange_t.rngmultitypid ) ) ''' INTRO_LOOKUP_TYPES = '''\ WITH RECURSIVE typeinfo_tree( oid, ns, name, kind, basetype, elemtype, elemdelim, range_subtype, attrtypoids, attrnames, depth) AS ( SELECT ti.oid, ti.ns, ti.name, 
def TypeRecord(
    rec: typing.Tuple[int, typing.Optional[int], bytes],
) -> protocol.Record:
    """Build a record with ``oid``, ``elemtype`` and ``kind`` fields.

    *rec* must be a 3-item sequence whose items appear in that order;
    the field-name-to-position mapping and *rec* are handed over to
    ``_create_record``.
    """
    field_positions = {"oid": 0, "elemtype": 1, "kind": 2}
    assert len(rec) == 3
    return _create_record(  # type: ignore
        field_positions, rec)
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')


def is_scalar_type(typeinfo: protocol.Record) -> bool:
    """Tell whether *typeinfo* describes a scalar type.

    A type is scalar when its ``kind`` is one of ``SCALAR_TYPE_KINDS``
    and its ``elemtype`` is falsy.
    """
    if typeinfo['kind'] not in SCALAR_TYPE_KINDS:
        return False
    return not typeinfo['elemtype']


def is_domain_type(typeinfo: protocol.Record) -> bool:
    """Tell whether the ``kind`` of *typeinfo* is ``b'd'``."""
    kind = typeinfo['kind']
    return kind == b'd'  # type: ignore[no-any-return]


def is_composite_type(typeinfo: protocol.Record) -> bool:
    """Tell whether the ``kind`` of *typeinfo* is ``b'c'``."""
    kind = typeinfo['kind']
    return kind == b'c'  # type: ignore[no-any-return]
# NOTE: no class docstring on purpose -- PoolConnectionProxyMeta copies
# Connection.__doc__ into the class only when '__doc__' is not defined here.
class PoolConnectionProxy(connection._ConnectionProxy,
                          metaclass=PoolConnectionProxyMeta,
                          wrap=True):

    __slots__ = ('_con', '_holder')

    def __init__(
        self,
        holder: PoolConnectionHolder,
        con: connection.Connection,
    ) -> None:
        """Bind the proxy to *con* and register the proxy on it."""
        self._con = con
        self._holder = holder
        con._set_proxy(self)

    def __getattr__(self, attr: str) -> Any:
        # Anything not found on the proxy itself is looked up on the
        # wrapped Connection object.
        return getattr(self._con, attr)

    def _detach(self) -> Optional[connection.Connection]:
        """Unbind the wrapped connection and return it.

        Returns ``None`` if the proxy has already been detached.
        """
        detached = self._con
        if detached is None:
            return
        self._con = None
        detached._set_proxy(None)
        return detached

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        if self._con is not None:
            return '<{classname} {con!r} {id:#x}>'.format(
                classname=cls_name, con=self._con, id=id(self))
        return '<{classname} [released] {id:#x}>'.format(
            classname=cls_name, id=id(self))
async def connect(self) -> None:
    """Open a new pool connection for this holder.

    Raises ``InternalClientError`` if the holder already has a connection.
    """
    if self._con is not None:
        raise exceptions.InternalClientError(
            'PoolConnectionHolder.connect() called while another '
            'connection already exists')

    pool = self._pool
    self._con = await pool._get_new_connection()
    self._generation = pool._generation
    # Restart the inactivity timer for the fresh connection.
    self._maybe_cancel_inactive_callback()
    self._setup_inactive_callback()


async def acquire(self) -> PoolConnectionProxy:
    """Check the holder out and return a proxy for its connection.

    The holder (re)connects first when it has no usable connection or
    when its connection belongs to an older pool generation.
    """
    current = self._con
    if current is None or current.is_closed():
        reconnect = True
    elif self._generation != self._pool._generation:
        # Connections have been expired: close the old one in the
        # background and re-connect the holder.
        self._pool._loop.create_task(
            current.close(timeout=self._timeout))
        reconnect = True
    else:
        reconnect = False

    if reconnect:
        self._con = None
        await self.connect()

    self._maybe_cancel_inactive_callback()

    self._proxy = proxy = PoolConnectionProxy(self, self._con)

    if self._setup is not None:
        try:
            await self._setup(proxy)
        except (Exception, asyncio.CancelledError) as ex:
            # A failed user-defined `setup` leaves the connection in an
            # unknown state, so it is not re-used.  The failure is not
            # necessarily an IO or protocol error, hence the graceful
            # close(); close() does the necessary cleanup via
            # _release_on_close().  A new connection will be created
            # when `acquire` is called again.
            try:
                await self._con.close()
            finally:
                raise ex

    self._in_use = self._pool._loop.create_future()

    return proxy


async def release(self, timeout: Optional[float]) -> None:
    """Reset the connection and hand the holder back to the pool.

    The connection is closed instead of being reset when it has reached
    ``max_queries`` or belongs to an older pool generation.  If the reset
    fails, the connection is terminated and the error is re-raised.
    """
    if self._in_use is None:
        raise exceptions.InternalClientError(
            'PoolConnectionHolder.release() called on '
            'a free connection holder')

    if self._con.is_closed():
        # A pool connection performs the necessary cleanup when it is
        # closed, so there is nothing left to do.
        return

    self._timeout = None

    # Close the connection if it hit its utilization limit, or if it
    # has expired (Pool.expire_connections() bumped the generation).
    # Connection.close() will call _release().
    worn_out = self._con._protocol.queries_count >= self._max_queries
    if worn_out or self._generation != self._pool._generation:
        await self._con.close(timeout=timeout)
        return

    try:
        remaining = timeout

        if self._con._protocol._is_cancelling():
            # Wait for the in-flight cancellation and charge the time
            # spent against the timeout budget.
            wait_started = time.monotonic()
            await compat.wait_for(
                self._con._protocol._wait_for_cancellation(),
                remaining)
            if remaining is not None:
                remaining -= time.monotonic() - wait_started

        if self._pool._reset is None:
            await self._con.reset(timeout=remaining)
        else:
            async with compat.timeout(remaining):
                await self._con._reset()
                await self._pool._reset(self._con)
    except (Exception, asyncio.CancelledError) as ex:
        # A failed reset is most likely caused by an IO error, so the
        # connection is terminated.  A new one will be created when
        # `acquire` is called again.
        try:
            self._con.terminate()
        finally:
            raise ex

    # Free this holder, invalidate the proxy, and rearm the
    # connection inactivity timer.
    self._release()
    self._setup_inactive_callback()


async def wait_until_released(self) -> None:
    """Wait until the holder is checked in; return at once if it is free."""
    pending = self._in_use
    if pending is not None:
        await pending


async def close(self) -> None:
    """Gracefully close the holder's connection, if it has one."""
    if self._con is None:
        return
    # Connection.close() will call _release_on_close() to
    # finish holder cleanup.
    await self._con.close()
def _setup_inactive_callback(self) -> None:
    """Arm the inactivity timer unless ``_max_inactive_time`` is falsy.

    Raises ``InternalClientError`` if a timer is already armed.
    """
    if self._inactive_callback is not None:
        raise exceptions.InternalClientError(
            'pool connection inactivity timer already exists')

    delay = self._max_inactive_time
    if delay:
        self._inactive_callback = self._pool._loop.call_later(
            delay, self._deactivate_inactive_connection)


def _maybe_cancel_inactive_callback(self) -> None:
    """Cancel and forget the inactivity timer, if one is armed."""
    timer = self._inactive_callback
    if timer is None:
        return
    timer.cancel()
    self._inactive_callback = None


def _deactivate_inactive_connection(self) -> None:
    """Terminate the idle connection of a holder that is not checked out.

    Raises ``InternalClientError`` if the holder is currently in use.
    """
    if self._in_use is not None:
        raise exceptions.InternalClientError(
            'attempting to deactivate an acquired connection')

    if self._con is None:
        return

    # The connection is idle and not in use, so terminate() is fine
    # in place of close().
    self._con.terminate()
    # The holder is *not* checked out here, so terminate() will not
    # perform the cleanup below on its own; do it explicitly.
    self._release_on_close()


def _release_on_close(self) -> None:
    """Finish holder cleanup after its connection is closed."""
    self._maybe_cancel_inactive_callback()
    self._release()
    self._con = None


def _release(self) -> None:
    """Check the holder back in to the pool queue.

    Does nothing when the holder is not checked out.
    """
    in_use = self._in_use
    if in_use is None:
        return

    if not in_use.done():
        in_use.set_result(None)
    self._in_use = None

    # Detach the connection proxy, so that all subsequent
    # operations on it fail.
    if self._proxy is not None:
        self._proxy._detach()
        self._proxy = None

    # Make the holder available in the pool queue again.
    self._pool._queue.put_nowait(self)
""" __slots__ = ( '_queue', '_loop', '_minsize', '_maxsize', '_init', '_connect', '_reset', '_connect_args', '_connect_kwargs', '_holders', '_initialized', '_initializing', '_closing', '_closed', '_connection_class', '_record_class', '_generation', '_setup', '_max_queries', '_max_inactive_connection_lifetime' ) def __init__(self, *connect_args, min_size, max_size, max_queries, max_inactive_connection_lifetime, connect=None, setup=None, init=None, reset=None, loop, connection_class, record_class, **connect_kwargs): if len(connect_args) > 1: warnings.warn( "Passing multiple positional arguments to asyncpg.Pool " "constructor is deprecated and will be removed in " "asyncpg 0.17.0. The non-deprecated form is " "asyncpg.Pool(, **kwargs)", DeprecationWarning, stacklevel=2) if loop is None: loop = asyncio.get_event_loop() self._loop = loop if max_size <= 0: raise ValueError('max_size is expected to be greater than zero') if min_size < 0: raise ValueError( 'min_size is expected to be greater or equal to zero') if min_size > max_size: raise ValueError('min_size is greater than max_size') if max_queries <= 0: raise ValueError('max_queries is expected to be greater than zero') if max_inactive_connection_lifetime < 0: raise ValueError( 'max_inactive_connection_lifetime is expected to be greater ' 'or equal to zero') if not issubclass(connection_class, connection.Connection): raise TypeError( 'connection_class is expected to be a subclass of ' 'asyncpg.Connection, got {!r}'.format(connection_class)) if not issubclass(record_class, protocol.Record): raise TypeError( 'record_class is expected to be a subclass of ' 'asyncpg.Record, got {!r}'.format(record_class)) self._minsize = min_size self._maxsize = max_size self._holders = [] self._initialized = False self._initializing = False self._queue = None self._connection_class = connection_class self._record_class = record_class self._closing = False self._closed = False self._generation = 0 self._connect = connect if connect is not 
None else connection.connect self._connect_args = connect_args self._connect_kwargs = connect_kwargs self._setup = setup self._init = init self._reset = reset self._max_queries = max_queries self._max_inactive_connection_lifetime = \ max_inactive_connection_lifetime async def _async__init__(self): if self._initialized: return self if self._initializing: raise exceptions.InterfaceError( 'pool is being initialized in another task') if self._closed: raise exceptions.InterfaceError('pool is closed') self._initializing = True try: await self._initialize() return self finally: self._initializing = False self._initialized = True async def _initialize(self): self._queue = asyncio.LifoQueue(maxsize=self._maxsize) for _ in range(self._maxsize): ch = PoolConnectionHolder( self, max_queries=self._max_queries, max_inactive_time=self._max_inactive_connection_lifetime, setup=self._setup) self._holders.append(ch) self._queue.put_nowait(ch) if self._minsize: # Since we use a LIFO queue, the first items in the queue will be # the last ones in `self._holders`. We want to pre-connect the # first few connections in the queue, therefore we want to walk # `self._holders` in reverse. # Connect the first connection holder in the queue so that # any connection issues are visible early. first_ch = self._holders[-1] # type: PoolConnectionHolder await first_ch.connect() if self._minsize > 1: connect_tasks = [] for i, ch in enumerate(reversed(self._holders[:-1])): # `minsize - 1` because we already have first_ch if i >= self._minsize - 1: break connect_tasks.append(ch.connect()) await asyncio.gather(*connect_tasks) def is_closing(self): """Return ``True`` if the pool is closing or is closed. .. versionadded:: 0.28.0 """ return self._closed or self._closing def get_size(self): """Return the current number of connections in this pool. .. versionadded:: 0.25.0 """ return sum(h.is_connected() for h in self._holders) def get_min_size(self): """Return the minimum number of connections in this pool. 
def get_max_size(self):
    """Return the upper bound on the number of connections in this pool.

    .. versionadded:: 0.25.0
    """
    return self._maxsize


def get_idle_size(self):
    """Return how many connected holders are currently not checked out.

    .. versionadded:: 0.25.0
    """
    idle = 0
    for holder in self._holders:
        idle += holder.is_connected() and holder.is_idle()
    return idle


def set_connect_args(self, dsn=None, **connect_kwargs):
    r"""Replace the arguments used for new connections of this pool.

    The new arguments apply to all subsequent connection attempts.
    Connections that are already open remain until they expire; use
    :meth:`Pool.expire_connections()
    <asyncpg.pool.Pool.expire_connections>` to expedite that.

    :param str dsn:
        Connection arguments specified as a single string in the
        following format:
        ``postgres://user:pass@host:port/database?option=value``.

    :param \*\*connect_kwargs:
        Keyword arguments for the :func:`~asyncpg.connection.connect`
        function.

    .. versionadded:: 0.16.0
    """
    self._connect_kwargs = connect_kwargs
    self._connect_args = [dsn]


async def _get_new_connection(self):
    """Open a connection, verify its type and run the ``init`` callback.

    :raises InterfaceError: if the connect callback returned an object
        that is not an instance of the pool's connection class.
    """
    con = await self._connect(
        *self._connect_args,
        loop=self._loop,
        connection_class=self._connection_class,
        record_class=self._record_class,
        **self._connect_kwargs,
    )

    expected_cls = self._connection_class
    if not isinstance(con, expected_cls):
        actual_cls = type(con)
        good_n = f'{expected_cls.__module__}.{expected_cls.__name__}'
        # Builtin types are reported without their module name.
        bad_n = (
            actual_cls.__name__
            if actual_cls.__module__ == "builtins"
            else f'{actual_cls.__module__}.{actual_cls.__name__}'
        )
        raise exceptions.InterfaceError(
            "expected pool connect callback to return an instance of "
            f"'{good_n}', got " f"'{bad_n}'"
        )

    if self._init is None:
        return con

    try:
        await self._init(con)
    except (Exception, asyncio.CancelledError) as ex:
        # A failed user-defined `init` leaves the connection in an
        # unknown state, so it is not re-used.  The failure is not
        # necessarily an IO or protocol error, hence the graceful
        # close(); close() does the necessary cleanup via
        # _release_on_close().  A new connection will be created
        # when `acquire` is called again.
        try:
            await con.close()
        finally:
            raise ex

    return con


async def execute(
    self,
    query: str,
    *args,
    timeout: Optional[float] = None,
) -> str:
    """Execute an SQL command (or commands) on a pooled connection.

    Apart from acquiring and releasing the connection, this behaves
    identically to :meth:`Connection.execute()
    <asyncpg.connection.Connection.execute>`.

    .. versionadded:: 0.10.0
    """
    async with self.acquire() as pooled:
        return await pooled.execute(query, *args, timeout=timeout)


async def executemany(
    self,
    command: str,
    args,
    *,
    timeout: Optional[float] = None,
):
    """Execute an SQL *command* for each sequence of arguments in *args*.

    Apart from acquiring and releasing the connection, this behaves
    identically to :meth:`Connection.executemany()
    <asyncpg.connection.Connection.executemany>`.

    .. versionadded:: 0.10.0
    """
    async with self.acquire() as pooled:
        return await pooled.executemany(command, args, timeout=timeout)


async def fetch(
    self,
    query,
    *args,
    timeout=None,
    record_class=None
) -> list:
    """Run a query and return the results as a list of :class:`Record`.

    Apart from acquiring and releasing the connection, this behaves
    identically to :meth:`Connection.fetch()
    <asyncpg.connection.Connection.fetch>`.

    .. versionadded:: 0.10.0
    """
    async with self.acquire() as pooled:
        return await pooled.fetch(
            query,
            *args,
            timeout=timeout,
            record_class=record_class
        )


async def fetchval(self, query, *args, column=0, timeout=None):
    """Run a query and return a value in the first row.

    Apart from acquiring and releasing the connection, this behaves
    identically to :meth:`Connection.fetchval()
    <asyncpg.connection.Connection.fetchval>`.

    .. versionadded:: 0.10.0
    """
    async with self.acquire() as pooled:
        return await pooled.fetchval(
            query, *args, column=column, timeout=timeout)
async def fetchmany(self, query, args, *, timeout=None, record_class=None):
    """Run a query for each sequence of arguments in *args*
    and return the results as a list of :class:`Record`.

    Apart from acquiring and releasing the connection, this behaves
    identically to :meth:`Connection.fetchmany()
    <asyncpg.connection.Connection.fetchmany>`.

    .. versionadded:: 0.30.0
    """
    async with self.acquire() as pooled:
        return await pooled.fetchmany(
            query, args, timeout=timeout, record_class=record_class
        )


async def copy_from_table(
    self,
    table_name,
    *,
    output,
    columns=None,
    schema_name=None,
    timeout=None,
    format=None,
    oids=None,
    delimiter=None,
    null=None,
    header=None,
    quote=None,
    escape=None,
    force_quote=None,
    encoding=None
):
    """Copy table contents to a file or file-like object.

    Apart from acquiring and releasing the connection, this behaves
    identically to :meth:`Connection.copy_from_table()
    <asyncpg.connection.Connection.copy_from_table>`.

    .. versionadded:: 0.24.0
    """
    # Every keyword argument is forwarded unchanged.
    options = dict(
        output=output,
        columns=columns,
        schema_name=schema_name,
        timeout=timeout,
        format=format,
        oids=oids,
        delimiter=delimiter,
        null=null,
        header=header,
        quote=quote,
        escape=escape,
        force_quote=force_quote,
        encoding=encoding,
    )
    async with self.acquire() as pooled:
        return await pooled.copy_from_table(table_name, **options)


async def copy_from_query(
    self,
    query,
    *args,
    output,
    timeout=None,
    format=None,
    oids=None,
    delimiter=None,
    null=None,
    header=None,
    quote=None,
    escape=None,
    force_quote=None,
    encoding=None
):
    """Copy the results of a query to a file or file-like object.

    Apart from acquiring and releasing the connection, this behaves
    identically to :meth:`Connection.copy_from_query()
    <asyncpg.connection.Connection.copy_from_query>`.

    .. versionadded:: 0.24.0
    """
    # Every keyword argument is forwarded unchanged.
    options = dict(
        output=output,
        timeout=timeout,
        format=format,
        oids=oids,
        delimiter=delimiter,
        null=null,
        header=header,
        quote=quote,
        escape=escape,
        force_quote=force_quote,
        encoding=encoding,
    )
    async with self.acquire() as pooled:
        return await pooled.copy_from_query(query, *args, **options)


async def copy_to_table(
    self,
    table_name,
    *,
    source,
    columns=None,
    schema_name=None,
    timeout=None,
    format=None,
    oids=None,
    freeze=None,
    delimiter=None,
    null=None,
    header=None,
    quote=None,
    escape=None,
    force_quote=None,
    force_not_null=None,
    force_null=None,
    encoding=None,
    where=None
):
    """Copy data to the specified table.

    Apart from acquiring and releasing the connection, this behaves
    identically to :meth:`Connection.copy_to_table()
    <asyncpg.connection.Connection.copy_to_table>`.

    .. versionadded:: 0.24.0
    """
    # Every keyword argument is forwarded unchanged.
    options = dict(
        source=source,
        columns=columns,
        schema_name=schema_name,
        timeout=timeout,
        format=format,
        oids=oids,
        freeze=freeze,
        delimiter=delimiter,
        null=null,
        header=header,
        quote=quote,
        escape=escape,
        force_quote=force_quote,
        force_not_null=force_not_null,
        force_null=force_null,
        encoding=encoding,
        where=where,
    )
    async with self.acquire() as pooled:
        return await pooled.copy_to_table(table_name, **options)


async def copy_records_to_table(
    self,
    table_name,
    *,
    records,
    columns=None,
    schema_name=None,
    timeout=None,
    where=None
):
    """Copy a list of records to the specified table using binary COPY.

    Apart from acquiring and releasing the connection, this behaves
    identically to :meth:`Connection.copy_records_to_table()
    <asyncpg.connection.Connection.copy_records_to_table>`.

    .. versionadded:: 0.24.0
    """
    # Every keyword argument is forwarded unchanged.
    options = dict(
        records=records,
        columns=columns,
        schema_name=schema_name,
        timeout=timeout,
        where=where,
    )
    async with self.acquire() as pooled:
        return await pooled.copy_records_to_table(table_name, **options)
async def _acquire(self, timeout):
    """Take a holder from the queue and return its connection proxy.

    :raises InterfaceError: if the pool is closing.
    """
    async def _checkout():
        ch = await self._queue.get()  # type: PoolConnectionHolder
        try:
            proxy = await ch.acquire()  # type: PoolConnectionProxy
        except (Exception, asyncio.CancelledError):
            # Give the holder back so that it is not lost.
            self._queue.put_nowait(ch)
            raise
        # Remember the timeout: release() applies it by default.
        ch._timeout = timeout
        return proxy

    if self._closing:
        raise exceptions.InterfaceError('pool is closing')
    self._check_init()

    if timeout is None:
        return await _checkout()
    return await compat.wait_for(_checkout(), timeout=timeout)


async def release(self, connection, *, timeout=None):
    """Release a database connection back to the pool.

    :param Connection connection:
        A :class:`~asyncpg.connection.Connection` object to release.
    :param float timeout:
        A timeout for releasing the connection.  If not specified,
        defaults to the timeout provided in the corresponding call to
        the :meth:`Pool.acquire() <asyncpg.pool.Pool.acquire>` method.

    .. versionchanged:: 0.14.0
        Added the *timeout* parameter.
    """
    belongs_here = (
        type(connection) is PoolConnectionProxy
        and connection._holder._pool is self
    )
    if not belongs_here:
        raise exceptions.InterfaceError(
            'Pool.release() received invalid connection: '
            '{connection!r} is not a member of this pool'.format(
                connection=connection))

    if connection._con is None:
        # The connection was released before; nothing to do.
        return

    self._check_init()

    # Give the connection a chance to do its internal housekeeping.
    connection._con._on_release()

    holder = connection._holder
    effective_timeout = holder._timeout if timeout is None else timeout

    # asyncio.shield() guarantees that task cancellation does not
    # prevent the connection from being returned to the pool properly.
    return await asyncio.shield(holder.release(effective_timeout))


async def close(self):
    """Attempt to gracefully close all connections in the pool.

    Wait until all pool connections are released, close them and
    shut down the pool.  If any error (including cancellation) occurs
    in ``close()`` the pool will terminate by calling
    :meth:`Pool.terminate() <asyncpg.pool.Pool.terminate>`.

    It is advisable to use :func:`python:asyncio.wait_for` to set
    a timeout.

    .. versionchanged:: 0.16.0
        ``close()`` now waits until all pool connections are released
        before closing them and the pool.  Errors raised in ``close()``
        will cause immediate pool termination.
    """
    if self._closed:
        return
    self._check_init()

    self._closing = True

    warning_callback = None
    try:
        warning_callback = self._loop.call_later(
            60, self._warn_on_long_close)

        await asyncio.gather(
            *(ch.wait_until_released() for ch in self._holders))
        await asyncio.gather(
            *(ch.close() for ch in self._holders))

    except (Exception, asyncio.CancelledError):
        self.terminate()
        raise

    finally:
        if warning_callback is not None:
            warning_callback.cancel()
        self._closed = True
        self._closing = False


def _warn_on_long_close(self):
    """Log a warning about a ``close()`` call that is taking too long."""
    logger.warning('Pool.close() is taking over 60 seconds to complete. '
                   'Check if you have any unreleased connections left. '
                   'Use asyncio.wait_for() to set a timeout for '
                   'Pool.close().')


def terminate(self):
    """Terminate all connections in the pool."""
    if self._closed:
        return
    self._check_init()
    for holder in self._holders:
        holder.terminate()
    self._closed = True
class PoolAcquireContext:
    """Awaitable and async context manager returned by ``Pool.acquire()``.

    Awaiting the object acquires a connection; using it in ``async with``
    acquires a connection on entry and releases it on exit.
    """

    __slots__ = ('timeout', 'connection', 'done', 'pool')

    def __init__(self, pool: Pool, timeout: Optional[float]) -> None:
        self.pool = pool
        self.timeout = timeout
        self.connection = None
        self.done = False

    async def __aenter__(self):
        """Acquire a connection; the context object can be used only once."""
        if self.done or self.connection is not None:
            raise exceptions.InterfaceError('a connection is already acquired')
        acquired = await self.pool._acquire(self.timeout)
        self.connection = acquired
        return self.connection

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_val: Optional[BaseException] = None,
        exc_tb: Optional[TracebackType] = None,
    ) -> None:
        """Release the acquired connection back to the pool."""
        self.done = True
        con, self.connection = self.connection, None
        await self.pool.release(con)

    def __await__(self):
        # Mark the context as used so that a later ``async with`` on the
        # same object is rejected.
        self.done = True
        return self.pool._acquire(self.timeout).__await__()
connection pool. Can be used either with an ``async with`` block: .. code-block:: python async with asyncpg.create_pool(user='postgres', command_timeout=60) as pool: await pool.fetch('SELECT 1') Or to perform multiple operations on a single connection: .. code-block:: python async with asyncpg.create_pool(user='postgres', command_timeout=60) as pool: async with pool.acquire() as con: await con.execute(''' CREATE TABLE names ( id serial PRIMARY KEY, name VARCHAR (255) NOT NULL) ''') await con.fetch('SELECT 1') Or directly with ``await`` (not recommended): .. code-block:: python pool = await asyncpg.create_pool(user='postgres', command_timeout=60) con = await pool.acquire() try: await con.fetch('SELECT 1') finally: await pool.release(con) .. warning:: Prepared statements and cursors returned by :meth:`Connection.prepare() ` and :meth:`Connection.cursor() ` become invalid once the connection is released. Likewise, all notification and log listeners are removed, and ``asyncpg`` will issue a warning if there are any listener callbacks registered on a connection that is being released to the pool. :param str dsn: Connection arguments specified using as a single string in the following format: ``postgres://user:pass@host:port/database?option=value``. :param \*\*connect_kwargs: Keyword arguments for the :func:`~asyncpg.connection.connect` function. :param Connection connection_class: The class to use for connections. Must be a subclass of :class:`~asyncpg.connection.Connection`. :param type record_class: If specified, the class to use for records returned by queries on the connections in this pool. Must be a subclass of :class:`~asyncpg.Record`. :param int min_size: Number of connection the pool will be initialized with. :param int max_size: Max number of connections in the pool. :param int max_queries: Number of queries after a connection is closed and replaced with a new connection. 
:param float max_inactive_connection_lifetime: Number of seconds after which inactive connections in the pool will be closed. Pass ``0`` to disable this mechanism. :param coroutine connect: A coroutine that is called instead of :func:`~asyncpg.connection.connect` whenever the pool needs to make a new connection. Must return an instance of type specified by *connection_class* or :class:`~asyncpg.connection.Connection` if *connection_class* was not specified. :param coroutine setup: A coroutine to prepare a connection right before it is returned from :meth:`Pool.acquire()`. An example use case would be to automatically set up notifications listeners for all connections of a pool. :param coroutine init: A coroutine to initialize a connection when it is created. An example use case would be to setup type codecs with :meth:`Connection.set_builtin_type_codec() <\ asyncpg.connection.Connection.set_builtin_type_codec>` or :meth:`Connection.set_type_codec() <\ asyncpg.connection.Connection.set_type_codec>`. :param coroutine reset: A coroutine to reset a connection before it is returned to the pool by :meth:`Pool.release()`. The function is supposed to reset any changes made to the database session so that the next acquirer gets the connection in a well-defined state. The default implementation calls :meth:`Connection.reset() <\ asyncpg.connection.Connection.reset>`, which runs the following:: SELECT pg_advisory_unlock_all(); CLOSE ALL; UNLISTEN *; RESET ALL; The exact reset query is determined by detected server capabilities, and a custom *reset* implementation can obtain the default query by calling :meth:`Connection.get_reset_query() <\ asyncpg.connection.Connection.get_reset_query>`. :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. :return: An instance of :class:`~asyncpg.pool.Pool`. .. versionchanged:: 0.10.0 An :exc:`~asyncpg.exceptions.InterfaceError` will be raised on any attempted operation on a released connection. .. 
class PreparedStatement(connresource.ConnectionResource):
    """A representation of a prepared statement."""

    __slots__ = ('_state', '_query', '_last_status')

    def __init__(self, connection, query, state):
        super().__init__(connection)
        self._state = state
        self._query = query
        state.attach()
        # Raw status of the last bind/execute; ``None`` until one completes.
        self._last_status = None

    @connresource.guarded
    def get_name(self) -> str:
        """Return the name of this prepared statement.

        .. versionadded:: 0.25.0
        """
        return self._state.name

    @connresource.guarded
    def get_query(self) -> str:
        """Return the text of the query for this prepared statement.

        Example::

            stmt = await connection.prepare('SELECT $1::int')
            assert stmt.get_query() == "SELECT $1::int"
        """
        return self._query

    @connresource.guarded
    def get_statusmsg(self) -> typing.Optional[str]:
        """Return the status of the executed command.

        The status is recorded by :meth:`fetch`, :meth:`fetchval` and
        :meth:`fetchrow`; ``None`` is returned if none of them has
        completed on this statement yet.

        Example::

            stmt = await connection.prepare('CREATE TABLE mytab (a int)')
            await stmt.fetch()
            assert stmt.get_statusmsg() == "CREATE TABLE"
        """
        if self._last_status is None:
            return None
        return self._last_status.decode()

    @connresource.guarded
    def get_parameters(self):
        """Return a description of statement parameters types.

        :return: A tuple of :class:`asyncpg.types.Type`.

        Example::

            stmt = await connection.prepare('SELECT ($1::int, $2::text)')
            print(stmt.get_parameters())

            # Will print:
            #   (Type(oid=23, name='int4', kind='scalar', schema='pg_catalog'),
            #    Type(oid=25, name='text', kind='scalar', schema='pg_catalog'))
        """
        return self._state._get_parameters()

    @connresource.guarded
    def get_attributes(self):
        """Return a description of relation attributes (columns).

        :return: A tuple of :class:`asyncpg.types.Attribute`.

        Example::

            st = await self.con.prepare('''
                SELECT typname, typnamespace FROM pg_type
            ''')
            print(st.get_attributes())

            # Will print:
            #   (Attribute(
            #       name='typname',
            #       type=Type(oid=19, name='name', kind='scalar',
            #                 schema='pg_catalog')),
            #    Attribute(
            #       name='typnamespace',
            #       type=Type(oid=26, name='oid', kind='scalar',
            #                 schema='pg_catalog')))
        """
        return self._state._get_attributes()

    @connresource.guarded
    def cursor(self, *args, prefetch=None,
               timeout=None) -> cursor.CursorFactory:
        """Return a *cursor factory* for the prepared statement.

        :param args: Query arguments.
        :param int prefetch: The number of rows the *cursor iterator*
                             will prefetch (defaults to ``50``.)
        :param float timeout: Optional timeout in seconds.

        :return: A :class:`~cursor.CursorFactory` object.
        """
        return cursor.CursorFactory(
            self._connection,
            self._query,
            self._state,
            args,
            prefetch,
            timeout,
            self._state.record_class,
        )

    @connresource.guarded
    async def explain(self, *args, analyze=False):
        """Return the execution plan of the statement.

        :param args: Query arguments.
        :param analyze: If ``True``, the statement will be executed and
                        the run time statistics added to the return value.
                        The execution happens inside a transaction that is
                        always rolled back afterwards.

        :return: An object representing the execution plan.  This value
                 is actually a deserialized JSON output of the SQL
                 ``EXPLAIN`` command.
        """
        query = 'EXPLAIN (FORMAT JSON, VERBOSE'
        if analyze:
            query += ', ANALYZE) '
        else:
            query += ') '
        query += self._state.query

        if analyze:
            # From PostgreSQL docs:
            # Important: Keep in mind that the statement is actually
            # executed when the ANALYZE option is used. Although EXPLAIN
            # will discard any output that a SELECT would return, other
            # side effects of the statement will happen as usual. If you
            # wish to use EXPLAIN ANALYZE on an INSERT, UPDATE, DELETE,
            # MERGE, CREATE TABLE AS, or EXECUTE statement without letting
            # the command affect your data, use this approach:
            #     BEGIN;
            #     EXPLAIN ANALYZE ...;
            #     ROLLBACK;
            tr = self._connection.transaction()
            await tr.start()
            try:
                data = await self._connection.fetchval(query, *args)
            finally:
                await tr.rollback()
        else:
            data = await self._connection.fetchval(query, *args)

        return json.loads(data)

    @connresource.guarded
    async def fetch(self, *args, timeout=None):
        """Execute the statement and return a list of :class:`Record` objects.

        :param args: Query arguments
        :param float timeout: Optional timeout value in seconds.

        :return: A list of :class:`Record` instances.
        """
        # A limit of 0 is passed to bind_execute for the full result set.
        data = await self.__bind_execute(args, 0, timeout)
        return data

    @connresource.guarded
    async def fetchval(self, *args, column=0, timeout=None):
        """Execute the statement and return a value in the first row.

        :param args: Query arguments.
        :param int column: Numeric index within the record of the value to
                           return (defaults to 0).
        :param float timeout: Optional timeout value in seconds.
                            If not specified, defaults to the value of
                            ``command_timeout`` argument to the ``Connection``
                            instance constructor.

        :return: The value of the specified column of the first record,
                 or ``None`` if the statement returned no rows.
        """
        data = await self.__bind_execute(args, 1, timeout)
        if not data:
            return None
        return data[0][column]

    @connresource.guarded
    async def fetchrow(self, *args, timeout=None):
        """Execute the statement and return the first row.

        :param args: Query arguments
        :param float timeout: Optional timeout value in seconds.

        :return: The first row as a :class:`Record` instance, or ``None``
                 if the statement returned no rows.
        """
        data = await self.__bind_execute(args, 1, timeout)
        if not data:
            return None
        return data[0]

    @connresource.guarded
    async def fetchmany(self, args, *, timeout=None):
        """Execute the statement and return a list of :class:`Record` objects.

        :param args: An iterable containing sequences of arguments.
        :param float timeout: Optional timeout value in seconds.

        :return: A list of :class:`Record` instances.

        .. versionadded:: 0.30.0
        """
        return await self.__do_execute(
            lambda protocol: protocol.bind_execute_many(
                self._state,
                args,
                portal_name='',
                timeout=timeout,
                return_rows=True,
            )
        )

    @connresource.guarded
    async def executemany(self, args, *, timeout: typing.Optional[float]=None):
        """Execute the statement for each sequence of arguments in *args*.

        :param args: An iterable containing sequences of arguments.
        :param float timeout: Optional timeout value in seconds.
        :return None: This method discards the results of the operations.

        .. versionadded:: 0.22.0
        """
        return await self.__do_execute(
            lambda protocol: protocol.bind_execute_many(
                self._state,
                args,
                portal_name='',
                timeout=timeout,
                return_rows=False,
            ))

    async def __do_execute(self, executor):
        """Run *executor* against the connection's protocol object.

        *executor* is a callable taking the protocol and returning an
        awaitable.  On :exc:`~asyncpg.exceptions.OutdatedSchemaCacheError`
        the connection's schema state is reloaded, this statement is
        marked closed, and the error is re-raised.
        """
        protocol = self._connection._protocol
        try:
            return await executor(protocol)
        except exceptions.OutdatedSchemaCacheError:
            await self._connection.reload_schema_state()
            # We can not find all manually created prepared statements, so just
            # drop known cached ones in the `self._connection`.
            # Other manually created prepared statements will fail and
            # invalidate themselves (unfortunately, clearing caches again).
            self._state.mark_closed()
            raise

    async def __bind_execute(self, args, limit, timeout):
        """Bind *args*, execute with the given row *limit*, return the rows.

        The command status returned by the protocol is remembered for
        :meth:`get_statusmsg`.
        """
        data, status, _ = await self.__do_execute(
            lambda protocol: protocol.bind_execute(
                self._state, args, '', limit, True, timeout))
        self._last_status = status
        return data

    def _check_open(self, meth_name):
        """Raise ``InterfaceError`` if the statement state is closed."""
        if self._state.closed:
            raise exceptions.InterfaceError(
                'cannot call PreparedStmt.{}(): '
                'the prepared statement is closed'.format(meth_name))

    def _check_conn_validity(self, meth_name):
        """Check the statement is open, then run the base-class check."""
        self._check_open(meth_name)
        super()._check_conn_validity(meth_name)

    def __del__(self):
        # Release our hold on the shared state and let the connection
        # decide whether the server-side statement can be collected.
        self._state.detach()
        self._connection._maybe_gc_stmt(self._state)
# Compute the shape of the (possibly nested) container *obj*.
#
# ``dims`` is a caller-provided buffer (the callers in this file declare it
# with ARRAY_MAXDIM slots) and ``ndims[0]`` is the 1-based nesting depth of
# *obj* on entry (the callers start it at 1).  The length of *obj* is stored
# in ``dims[ndims[0] - 1]``.  If the elements of *obj* are themselves
# sub-array iterables, ``ndims[0]`` is incremented and the function recurses
# into the FIRST such element only; every following sibling is merely
# required to have the same length.
#
# Raises ValueError if a length exceeds _MAXINT32, if the nesting depth
# exceeds ARRAY_MAXDIM, or if the container mixes scalars with sub-arrays
# or holds sub-arrays of different lengths ("non-homogeneous array").
cdef _get_array_shape(object obj, int32_t *dims, int32_t *ndims):
    cdef:
        ssize_t mylen = len(obj)
        # Sentinel: -2 = no element seen yet, -1 = scalar elements seen,
        # >= 0 = length of the first sub-array seen at this level.
        ssize_t elemlen = -2
        object it

    if mylen > _MAXINT32:
        raise ValueError('too many elements in array value')

    # Checked before the write below so that dims[ndims[0] - 1] stays
    # within the ARRAY_MAXDIM slots.
    if ndims[0] > ARRAY_MAXDIM:
        raise ValueError(
            'number of array dimensions ({}) exceed the maximum expected ({})'.
            format(ndims[0], ARRAY_MAXDIM))

    dims[ndims[0] - 1] = mylen

    for elem in obj:
        if _is_sub_array_iterable(elem):
            if elemlen == -2:
                # First element at this level is a sub-array: record its
                # length and descend one dimension.
                elemlen = len(elem)
                if elemlen > _MAXINT32:
                    raise ValueError('too many elements in array value')
                ndims[0] += 1
                _get_array_shape(elem, dims, ndims)
            else:
                # Also trips when scalars were seen earlier (elemlen == -1).
                if len(elem) != elemlen:
                    raise ValueError('non-homogeneous array')
        else:
            if elemlen >= 0:
                # A scalar following sub-arrays.
                raise ValueError('non-homogeneous array')
            else:
                elemlen = -1
buf.write_buffer(elem_data) cdef _write_textarray_data(ConnectionSettings settings, object obj, int32_t ndims, int32_t dim, WriteBuffer array_data, encode_func_ex encoder, const void *encoder_arg, Py_UCS4 typdelim): cdef: ssize_t i = 0 int8_t delim = typdelim WriteBuffer elem_data Py_buffer pybuf const char *elem_str char ch ssize_t elem_len ssize_t quoted_elem_len bint need_quoting array_data.write_byte(b'{') if dim < ndims - 1: for item in obj: if i > 0: array_data.write_byte(delim) array_data.write_byte(b' ') _write_textarray_data(settings, item, ndims, dim + 1, array_data, encoder, encoder_arg, typdelim) i += 1 else: for item in obj: elem_data = WriteBuffer.new() if i > 0: array_data.write_byte(delim) array_data.write_byte(b' ') if item is None: array_data.write_bytes(b'NULL') i += 1 continue else: try: encoder(settings, elem_data, item, encoder_arg) except TypeError as e: raise ValueError( 'invalid array element: {}'.format( e.args[0])) from None # element string length (first four bytes are the encoded length.) elem_len = elem_data.len() - 4 if elem_len == 0: # Empty string array_data.write_bytes(b'""') else: cpython.PyObject_GetBuffer( elem_data, &pybuf, cpython.PyBUF_SIMPLE) elem_str = (pybuf.buf) + 4 try: if not apg_strcasecmp_char(elem_str, b'NULL'): array_data.write_byte(b'"') array_data.write_cstr(elem_str, 4) array_data.write_byte(b'"') else: quoted_elem_len = elem_len need_quoting = False for i in range(elem_len): ch = elem_str[i] if ch == b'"' or ch == b'\\': # Quotes and backslashes need escaping. quoted_elem_len += 1 need_quoting = True elif (ch == b'{' or ch == b'}' or ch == delim or apg_ascii_isspace(ch)): need_quoting = True if need_quoting: array_data.write_byte(b'"') if quoted_elem_len == elem_len: array_data.write_cstr(elem_str, elem_len) else: # Escaping required. 
# Encode the container *obj* as a text-format array literal into *buf*.
#
# *encoder* / *encoder_arg* are forwarded to _write_textarray_data() to
# encode the individual elements, and *typdelim* is the element delimiter
# character.  The rendered literal is written to *buf* prefixed with its
# length as an int32.
#
# Raises TypeError if *obj* is not accepted by _is_array_iterable();
# errors from _get_array_shape() and _write_textarray_data() propagate.
cdef inline textarray_encode(ConnectionSettings settings, WriteBuffer buf,
                             object obj, encode_func_ex encoder,
                             const void *encoder_arg, Py_UCS4 typdelim):
    cdef:
        WriteBuffer array_data
        int32_t dims[ARRAY_MAXDIM]
        # Depth counter for _get_array_shape(); a flat container is 1-D.
        int32_t ndims = 1
        int32_t i

    if not _is_array_iterable(obj):
        raise TypeError(
            'a sized iterable container expected (got type {!r})'.format(
                type(obj).__name__))

    # Validates the nesting and fills in ndims; the per-dimension lengths
    # in dims are not used further in this function.
    _get_array_shape(obj, dims, &ndims)

    # Render into a scratch buffer first so the total length is known
    # before anything is written to the output buffer.
    array_data = WriteBuffer.new()
    _write_textarray_data(settings, obj, ndims, 0, array_data,
                          encoder, encoder_arg, typdelim)
    buf.write_int32(array_data.len())
    buf.write_buffer(array_data)
# Decode the element data of a multi-dimensional binary array into nested
# Python lists.
#
# *buf* is positioned at the first element; each element is an int32 length
# (-1 meaning NULL, decoded as None) followed by that many bytes, which are
# sliced into *elem_buf* and passed to *decoder* together with
# *decoder_arg*.  *ndims* / *dims* give the array shape.  Elements are read
# sequentially and placed into the innermost list, with inner lists being
# attached to their parents as they fill up.
#
# Returns the outermost list, or [] if any dimension is zero-sized.
# Raises ProtocolError if the product of the dimensions is negative.
#
# NOTE(review): ``strides`` holds ``void *`` while ``stride`` is a Python
# object; the assignments between them below show no explicit casts in this
# copy of the source -- verify against upstream before compiling.
cdef _nested_array_decode(ConnectionSettings settings,
                          FRBuffer *buf,
                          decode_func_ex decoder,
                          const void *decoder_arg,
                          int32_t ndims, int32_t *dims,
                          FRBuffer *elem_buf):

    cdef:
        int32_t elem_len
        int64_t i, j
        int64_t array_len = 1
        object elem, stride
        # An array of pointers to lists for each current array level.
        void *strides[ARRAY_MAXDIM]
        # An array of current positions at each array level.
        int32_t indexes[ARRAY_MAXDIM]

    # Total element count is the product of all dimension sizes.
    for i in range(ndims):
        array_len *= dims[i]
        indexes[i] = 0
        strides[i] = NULL

    if array_len == 0:
        # A multidimensional array with a zero-sized dimension?
        return []

    elif array_len < 0:
        # Array length overflow
        raise exceptions.ProtocolError('array length overflow')

    for i in range(array_len):
        # Decode the element.
        elem_len = hton.unpack_int32(frb_read(buf, 4))
        if elem_len == -1:
            elem = None
        else:
            elem = decoder(settings,
                           frb_slice_from(elem_buf, buf, elem_len),
                           decoder_arg)

        # Take an explicit reference, as PyList_SET_ITEM in the loop
        # below expects this.
        cpython.Py_INCREF(elem)

        # Iterate over array dimensions, innermost first, and put the
        # element in the correctly nested sublist.
        for j in reversed(range(ndims)):
            if indexes[j] == 0:
                # Allocate the list for this array level.
                stride = cpython.PyList_New(dims[j])

                strides[j] = stride
                # Take an explicit reference, as PyList_SET_ITEM below
                # expects this.
                cpython.Py_INCREF(stride)

            stride = strides[j]
            cpython.PyList_SET_ITEM(stride, indexes[j], elem)
            indexes[j] += 1

            if indexes[j] == dims[j] and j != 0:
                # This array level is full, continue the
                # ascent in the dimensions so that this level
                # sublist will be appended to the parent list.
                elem = stride
                # Reset the index, this will cause the
                # new list to be allocated on the next
                # iteration on this array axis.
                indexes[j] = 0
            else:
                break

    stride = strides[0]
    # Since each element in strides has a refcount of 1,
    # returning strides[0] will increment it to 2, so
    # balance that.
    cpython.Py_DECREF(stride)
    return stride
# Decode a text-format array literal into (nested) Python lists.
#
# *array_text* is a NUL-terminated, writable UCS4 copy of the literal; it
# is modified IN PLACE while parsing (quoting and backslash escapes are
# folded and each element is NUL-terminated), so the caller must pass a
# private copy.  *decoder* / *decoder_arg* decode the text of each element
# and *typdelim* is the element delimiter character.
#
# The literal may start with an explicit dimension decoration such as
# ``[1:3][1:2]=``.  When present, it is checked against the dimensions
# inferred from the brace structure by _infer_array_dims(); otherwise the
# inferred dimensions are used directly.
#
# Unquoted elements equal to NULL (compared with apg_strcasecmp) decode as
# None.  Raises ValueError for malformed literals.
#
# NOTE(review): ``strides`` holds ``void *`` while the lists stored in it
# are Python objects; the accesses below show no explicit casts in this
# copy of the source -- verify against upstream before compiling.
cdef _textarray_decode(ConnectionSettings settings,
                       Py_UCS4 *array_text,
                       decode_func_ex decoder,
                       const void *decoder_arg,
                       Py_UCS4 typdelim):

    cdef:
        bytearray array_bytes
        list result
        list new_stride
        Py_UCS4 *ptr
        int32_t ndims = 0
        int32_t ubound = 0
        int32_t lbound = 0
        int32_t dims[ARRAY_MAXDIM]
        int32_t inferred_dims[ARRAY_MAXDIM]
        int32_t inferred_ndims = 0
        void *strides[ARRAY_MAXDIM]
        int32_t indexes[ARRAY_MAXDIM]
        int32_t nest_level = 0
        int32_t item_level = 0
        bint end_of_array = False

        bint end_of_item = False
        bint has_quoting = False
        bint strip_spaces = False
        bint in_quotes = False
        Py_UCS4 *item_start
        Py_UCS4 *item_ptr
        Py_UCS4 *item_end

        int i
        object item
        str item_text
        FRBuffer item_buf
        char *pg_item_str
        ssize_t pg_item_len

    ptr = array_text

    # Parse the optional explicit dimension decoration.
    while True:
        while apg_ascii_isspace(ptr[0]):
            ptr += 1

        if ptr[0] != '[':
            # Finished parsing dimensions spec.
            break

        ptr += 1  # '['

        # dims[] has exactly ARRAY_MAXDIM slots and dims[ndims] is written
        # below, so another dimension must be rejected as soon as ndims
        # has reached ARRAY_MAXDIM.  (The check used to be `>`, which
        # allowed a write one slot past the end of dims.)
        if ndims >= ARRAY_MAXDIM:
            raise ValueError(
                'number of array dimensions ({}) exceed the '
                'maximum expected ({})'.format(ndims + 1, ARRAY_MAXDIM))

        ptr = apg_parse_int32(ptr, &ubound)
        if ptr == NULL:
            raise ValueError('missing array dimension value')

        if ptr[0] == ':':
            ptr += 1
            lbound = ubound

            # [lower:upper] spec.  We disregard the lbound for decoding.
            ptr = apg_parse_int32(ptr, &ubound)
            if ptr == NULL:
                raise ValueError('missing array dimension value')
        else:
            lbound = 1

        if ptr[0] != ']':
            raise ValueError('missing \']\' after array dimensions')

        ptr += 1  # ']'

        dims[ndims] = ubound - lbound + 1
        ndims += 1

    if ndims != 0:
        # If dimensions were given, the '=' token is expected.
        if ptr[0] != '=':
            raise ValueError('missing \'=\' after array dimensions')

        ptr += 1  # '='

        # Skip any whitespace after the '=', whitespace
        # before was consumed in the above loop.
        while apg_ascii_isspace(ptr[0]):
            ptr += 1

        # Infer the dimensions from the brace structure in the
        # array literal body, and check that it matches the explicit
        # spec.  This also validates that the array literal is sane.
        _infer_array_dims(ptr, typdelim, inferred_dims, &inferred_ndims)

        if inferred_ndims != ndims:
            raise ValueError(
                'specified array dimensions do not match array content')

        for i in range(ndims):
            if inferred_dims[i] != dims[i]:
                raise ValueError(
                    'specified array dimensions do not match array content')
    else:
        # Infer the dimensions from the brace structure in the array literal
        # body.  This also validates that the array literal is sane.
        _infer_array_dims(ptr, typdelim, dims, &ndims)

    while not end_of_array:
        # We iterate over the literal character by character
        # and modify the string in-place removing the array-specific
        # quoting and determining the boundaries of each element.
        end_of_item = has_quoting = in_quotes = False
        strip_spaces = True

        # Pointers to array element start, end, and the current pointer
        # tracking the position where characters are written when
        # escaping is folded.
        item_start = item_end = item_ptr = ptr
        item_level = 0

        while not end_of_item:
            if ptr[0] == '"':
                in_quotes = not in_quotes
                if in_quotes:
                    strip_spaces = False
                else:
                    item_end = item_ptr
                has_quoting = True

            elif ptr[0] == '\\':
                # Quoted character, collapse the backslash.
                ptr += 1
                has_quoting = True
                item_ptr[0] = ptr[0]
                item_ptr += 1
                strip_spaces = False
                item_end = item_ptr

            elif in_quotes:
                # Consume the string until we see the closing quote.
                item_ptr[0] = ptr[0]
                item_ptr += 1

            elif ptr[0] == '{':
                # Nesting level increase.
                nest_level += 1

                indexes[nest_level - 1] = 0
                new_stride = cpython.PyList_New(dims[nest_level - 1])
                strides[nest_level - 1] = \
                    (new_stride)

                if nest_level > 1:
                    # PyList_SET_ITEM steals a reference, so take one
                    # for the parent list.
                    cpython.Py_INCREF(new_stride)
                    cpython.PyList_SET_ITEM(
                        strides[nest_level - 2],
                        indexes[nest_level - 2],
                        new_stride)
                else:
                    result = new_stride

            elif ptr[0] == '}':
                if item_level == 0:
                    # Make sure we keep track of which nesting
                    # level the item belongs to, as the loop
                    # will continue to consume closing braces
                    # until the delimiter or the end of input.
                    item_level = nest_level

                nest_level -= 1

                if nest_level == 0:
                    end_of_array = end_of_item = True

            elif ptr[0] == typdelim:
                # Array element delimiter,
                end_of_item = True
                if item_level == 0:
                    item_level = nest_level

            elif apg_ascii_isspace(ptr[0]):
                if not strip_spaces:
                    item_ptr[0] = ptr[0]
                    item_ptr += 1
                # Ignore the leading literal whitespace.

            else:
                item_ptr[0] = ptr[0]
                item_ptr += 1
                strip_spaces = False
                item_end = item_ptr

            ptr += 1

        # end while not end_of_item

        if item_end == item_start:
            # Empty array
            continue

        # Terminate the folded element text in place.
        item_end[0] = '\0'

        if not has_quoting and apg_strcasecmp(item_start, APG_NULL) == 0:
            # NULL element.
            item = None
        else:
            # XXX: find a way to avoid the redundant encode/decode
            # cycle here.
            item_text = cpythonx.PyUnicode_FromKindAndData(
                cpythonx.PyUnicode_4BYTE_KIND,
                item_start,
                item_end - item_start)

            # Prepare the element buffer and call the text decoder
            # for the element type.
            pgproto.as_pg_string_and_size(
                settings, item_text, &pg_item_str, &pg_item_len)
            frb_init(&item_buf, pg_item_str, pg_item_len)
            item = decoder(settings, &item_buf, decoder_arg)

        # Place the decoded element in the array.
        cpython.Py_INCREF(item)
        cpython.PyList_SET_ITEM(
            strides[item_level - 1],
            indexes[item_level - 1],
            item)

        if nest_level > 0:
            indexes[nest_level - 1] += 1

    return result
pass elif ptr[0] == '{': if parse_state not in (APS_START, APS_STRIDE_STARTED, APS_STRIDE_DELIMITED): raise _UnexpectedCharacter(array_text, ptr) parse_state = APS_STRIDE_STARTED if nest_level >= ARRAY_MAXDIM: raise ValueError( 'number of array dimensions ({}) exceed the ' 'maximum expected ({})'.format( nest_level, ARRAY_MAXDIM)) dims[nest_level] = 0 nest_level += 1 if ndims[0] < nest_level: ndims[0] = nest_level elif ptr[0] == '}': if (parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE) and not (nest_level == 1 and parse_state == APS_STRIDE_STARTED)): raise _UnexpectedCharacter(array_text, ptr) parse_state = APS_STRIDE_DONE if nest_level == 0: raise _UnexpectedCharacter(array_text, ptr) nest_level -= 1 if (prev_stride_len[nest_level] != 0 and stride_len[nest_level] != prev_stride_len[nest_level]): raise ValueError( 'inconsistent sub-array dimensions' ' at position {}'.format( ptr - array_text + 1)) prev_stride_len[nest_level] = stride_len[nest_level] stride_len[nest_level] = 1 if nest_level == 0: end_of_array = end_of_item = True else: dims[nest_level - 1] += 1 elif ptr[0] == typdelim: if parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE): raise _UnexpectedCharacter(array_text, ptr) if parse_state == APS_STRIDE_DONE: parse_state = APS_STRIDE_DELIMITED else: parse_state = APS_ELEM_DELIMITED end_of_item = True stride_len[nest_level - 1] += 1 elif not apg_ascii_isspace(ptr[0]): if parse_state not in (APS_STRIDE_STARTED, APS_ELEM_STARTED, APS_ELEM_DELIMITED): raise _UnexpectedCharacter(array_text, ptr) parse_state = APS_ELEM_STARTED array_is_empty = False if not end_of_item: ptr += 1 if not array_is_empty: dims[ndims[0] - 1] += 1 ptr += 1 # only whitespace is allowed after the closing brace while ptr[0] != '\0': if not apg_ascii_isspace(ptr[0]): raise _UnexpectedCharacter(array_text, ptr) ptr += 1 if array_is_empty: ndims[0] = 0 cdef uint4_encode_ex(ConnectionSettings settings, WriteBuffer buf, object obj, const void *arg): return 
cdef uint4_decode_ex(ConnectionSettings settings, FRBuffer *buf,
                     const void *arg):
    """Array-element decode callback for uint4 values.

    Adapts ``pgproto.uint4_decode`` to the callback signature that takes
    an extra opaque ``arg``; the argument is accepted and ignored.
    """
    decoded = pgproto.uint4_decode(settings, buf)
    return decoded
# How values are exchanged between client code and a codec.
#
# The numeric values are load-bearing: the core codec lookup and
# registration routines in base.pyx compute the table slot as
# ``oid * xformat``, so OBJECT (1) maps an OID to slot ``oid`` and
# TUPLE (2) maps it to slot ``oid * 2``.  Do not renumber these
# without updating that indexing scheme.
cdef enum ClientExchangeFormat:
    PG_XFORMAT_OBJECT = 1
    PG_XFORMAT_TUPLE = 2
# Per-connection codec registry (declaration; implemented in base.pyx).
cdef class DataCodecConfig:
    cdef:
        # Codec cache for derived types (composites, arrays, ranges,
        # domains and their combinations), keyed by (oid, format).
        dict _derived_type_codecs
        # Codecs set up explicitly by the user, keyed by (oid, format).
        dict _custom_type_codecs

    # Look up a codec for ``oid`` in the given wire format; custom codecs
    # are consulted first unless ``ignore_custom_codec`` is true.
    cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
                                bint ignore_custom_codec=*)
    # Look up only among the user-configured codecs.
    cdef inline Codec get_custom_codec(self, uint32_t oid,
                                       ServerDataFormat format)
@cython.final
cdef class CodecMap:
    """Holder for the core codec lookup tables.

    ``binary_codec_map`` and ``text_codec_map`` are flat arrays of raw
    ``Codec`` pointers with ``(MAXSUPPORTEDOID + 1) * 2`` slots each;
    ``extra_codecs`` is a dict for codecs registered by name.
    """

    def __cinit__(self):
        self.extra_codecs = {}
        self.binary_codec_map = <void **>cpython.PyMem_Calloc(
            (MAXSUPPORTEDOID + 1) * 2, sizeof(void *))
        self.text_codec_map = <void **>cpython.PyMem_Calloc(
            (MAXSUPPORTEDOID + 1) * 2, sizeof(void *))
        # PyMem_Calloc returns NULL on failure; without this check the
        # accessors below would dereference a NULL table.
        if self.binary_codec_map is NULL or self.text_codec_map is NULL:
            raise MemoryError()

    def __dealloc__(self):
        # Also runs if __cinit__ failed part-way; PyMem_Free(NULL) is a
        # no-op, so a partially initialized instance is handled too.
        cpython.PyMem_Free(self.binary_codec_map)
        cpython.PyMem_Free(self.text_codec_map)

    cdef inline void *get_binary_codec_ptr(self, uint32_t idx):
        # No bounds checking: callers must keep idx within the table.
        return self.binary_codec_map[idx]

    cdef inline void set_binary_codec_ptr(self, uint32_t idx, void *ptr):
        self.binary_codec_map[idx] = ptr

    cdef inline void *get_text_codec_ptr(self, uint32_t idx):
        # No bounds checking: callers must keep idx within the table.
        return self.text_codec_map[idx]

    cdef inline void set_text_codec_ptr(self, uint32_t idx, void *ptr):
        self.text_codec_map[idx] = ptr
exceptions.InternalClientError( 'base_codec is mutually exclusive with c_encoder/c_decoder' ) if element_names is not None: self.record_desc = RecordDescriptor( element_names, tuple(element_names)) else: self.record_desc = None if type == CODEC_C: self.encoder = &self.encode_scalar self.decoder = &self.decode_scalar elif type == CODEC_ARRAY: if format == PG_FORMAT_BINARY: self.encoder = &self.encode_array self.decoder = &self.decode_array else: self.encoder = &self.encode_array_text self.decoder = &self.decode_array_text elif type == CODEC_RANGE: if format != PG_FORMAT_BINARY: raise exceptions.UnsupportedClientFeatureError( 'cannot decode type "{}"."{}": text encoding of ' 'range types is not supported'.format(schema, name)) self.encoder = &self.encode_range self.decoder = &self.decode_range elif type == CODEC_MULTIRANGE: if format != PG_FORMAT_BINARY: raise exceptions.UnsupportedClientFeatureError( 'cannot decode type "{}"."{}": text encoding of ' 'range types is not supported'.format(schema, name)) self.encoder = &self.encode_multirange self.decoder = &self.decode_multirange elif type == CODEC_COMPOSITE: if format != PG_FORMAT_BINARY: raise exceptions.UnsupportedClientFeatureError( 'cannot decode type "{}"."{}": text encoding of ' 'composite types is not supported'.format(schema, name)) self.encoder = &self.encode_composite self.decoder = &self.decode_composite elif type == CODEC_PY: self.encoder = &self.encode_in_python self.decoder = &self.decode_in_python else: raise exceptions.InternalClientError( 'unexpected codec type: {}'.format(type)) cdef Codec copy(self): cdef Codec codec codec = Codec(self.oid) codec.init(self.name, self.schema, self.kind, self.type, self.format, self.xformat, self.c_encoder, self.c_decoder, self.base_codec, self.py_encoder, self.py_decoder, self.element_codec, self.element_type_oids, self.element_names, self.element_codecs, self.element_delimiter) return codec cdef encode_scalar(self, ConnectionSettings settings, WriteBuffer buf, 
object obj): self.c_encoder(settings, buf, obj) cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf, object obj): array_encode(settings, buf, obj, self.element_codec.oid, codec_encode_func_ex, (self.element_codec)) cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf, object obj): return textarray_encode(settings, buf, obj, codec_encode_func_ex, (self.element_codec), self.element_delimiter) cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf, object obj): range_encode(settings, buf, obj, self.element_codec.oid, codec_encode_func_ex, (self.element_codec)) cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf, object obj): multirange_encode(settings, buf, obj, self.element_codec.oid, codec_encode_func_ex, (self.element_codec)) cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf, object obj): cdef: WriteBuffer elem_data int i list elem_codecs = self.element_codecs ssize_t count ssize_t composite_size tuple rec if isinstance(obj, MappingABC): # Input is dict-like, form a tuple composite_size = len(self.element_type_oids) rec = cpython.PyTuple_New(composite_size) for i in range(composite_size): cpython.Py_INCREF(None) cpython.PyTuple_SET_ITEM(rec, i, None) for field in obj: try: i = self.element_names[field] except KeyError: raise ValueError( '{!r} is not a valid element of composite ' 'type {}'.format(field, self.name)) from None item = obj[field] cpython.Py_INCREF(item) cpython.PyTuple_SET_ITEM(rec, i, item) obj = rec count = len(obj) if count > _MAXINT32: raise ValueError('too many elements in composite type record') elem_data = WriteBuffer.new() i = 0 for item in obj: elem_data.write_int32(self.element_type_oids[i]) if item is None: elem_data.write_int32(-1) else: (elem_codecs[i]).encode(settings, elem_data, item) i += 1 record_encode_frame(settings, buf, elem_data, count) cdef encode_in_python(self, ConnectionSettings settings, WriteBuffer buf, object obj): data = 
    cdef decode_composite(self, ConnectionSettings settings,
                          FRBuffer *buf):
        """Decode a composite (row) value from ``buf`` into a record.

        The wire data is read as a 4-byte attribute count followed, for
        each attribute, by a 4-byte type OID, a 4-byte length and that
        many bytes of payload (a length of -1 denotes NULL).

        Raises OutdatedSchemaCacheError when the attribute count or an
        attribute's type OID differs from what this codec was built with.
        """
        cdef:
            object result
            ssize_t elem_count
            ssize_t i
            int32_t elem_len
            uint32_t elem_typ
            uint32_t received_elem_typ
            Codec elem_codec
            FRBuffer elem_buf

        elem_count = hton.unpack_int32(frb_read(buf, 4))
        # The server-side definition must still match the cached one.
        if elem_count != len(self.element_type_oids):
            raise exceptions.OutdatedSchemaCacheError(
                'unexpected number of attributes of composite type: '
                '{}, expected {}'
                    .format(
                        elem_count,
                        len(self.element_type_oids),
                    ),
                schema=self.schema,
                data_type=self.name,
            )
        result = self.record_desc.make_record(asyncpg.Record, elem_count)
        for i in range(elem_count):
            elem_typ = self.element_type_oids[i]
            received_elem_typ = hton.unpack_int32(frb_read(buf, 4))

            if received_elem_typ != elem_typ:
                # Report type names where known, raw OIDs otherwise.
                raise exceptions.OutdatedSchemaCacheError(
                    'unexpected data type of composite type attribute {}: '
                    '{!r}, expected {!r}'
                        .format(
                            i,
                            BUILTIN_TYPE_OID_MAP.get(
                                received_elem_typ, received_elem_typ),
                            BUILTIN_TYPE_OID_MAP.get(
                                elem_typ, elem_typ)
                        ),
                    schema=self.schema,
                    data_type=self.name,
                    position=i,
                )

            elem_len = hton.unpack_int32(frb_read(buf, 4))
            if elem_len == -1:
                # NULL attribute: no payload follows.
                elem = None
            else:
                elem_codec = self.element_codecs[i]
                # Decode from a sub-buffer limited to this element's bytes.
                elem = elem_codec.decode(
                    settings, frb_slice_from(&elem_buf, buf, elem_len))

            # NOTE(review): the explicit INCREF suggests ApgRecord_SET_ITEM
            # takes over a reference like the tuple SET_ITEM macros do --
            # confirm against recordcapi.
            cpython.Py_INCREF(elem)
            recordcapi.ApgRecord_SET_ITEM(result, i, elem)

        return result
Codec elem_codec if self.c_decoder is not NULL or self.py_decoder is not None: return True elif ( self.type == CODEC_ARRAY or self.type == CODEC_RANGE or self.type == CODEC_MULTIRANGE ): return self.element_codec.has_decoder() elif self.type == CODEC_COMPOSITE: for elem_codec in self.element_codecs: if not elem_codec.has_decoder(): return False return True else: return False cdef is_binary(self): return self.format == PG_FORMAT_BINARY def __repr__(self): return ''.format( self.oid, 'NA' if self.element_codec is None else self.element_codec.oid, has_core_codec(self.oid)) @staticmethod cdef Codec new_array_codec(uint32_t oid, str name, str schema, Codec element_codec, Py_UCS4 element_delimiter): cdef Codec codec codec = Codec(oid) codec.init(name, schema, 'array', CODEC_ARRAY, element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, element_codec, None, None, None, element_delimiter) return codec @staticmethod cdef Codec new_range_codec(uint32_t oid, str name, str schema, Codec element_codec): cdef Codec codec codec = Codec(oid) codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, element_codec, None, None, None, 0) return codec @staticmethod cdef Codec new_multirange_codec(uint32_t oid, str name, str schema, Codec element_codec): cdef Codec codec codec = Codec(oid) codec.init(name, schema, 'multirange', CODEC_MULTIRANGE, element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, element_codec, None, None, None, 0) return codec @staticmethod cdef Codec new_composite_codec(uint32_t oid, str name, str schema, ServerDataFormat format, list element_codecs, tuple element_type_oids, object element_names): cdef Codec codec codec = Codec(oid) codec.init(name, schema, 'composite', CODEC_COMPOSITE, format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None, None, element_type_oids, element_names, element_codecs, 0) return codec @staticmethod cdef Codec new_python_codec(uint32_t oid, str 
cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
    """Validate that ``val`` fits an unsigned 32-bit OID and return it.

    Raises OverflowError if the value is negative, greater than
    UINT32_MAX, or too large to be read as a C long long at all.
    """
    cdef:
        int64_t as_int64 = 0
        bint in_range

    try:
        as_int64 = cpython.PyLong_AsLongLong(val)
    except OverflowError:
        # Does not even fit in 64 bits.
        in_range = False
    else:
        in_range = 0 <= as_int64 <= UINT32_MAX

    if not in_range:
        raise OverflowError('OID value too large: {!r}'.format(val))

    return val
self._custom_type_codecs = {} def add_types(self, types): cdef: Codec elem_codec list comp_elem_codecs ServerDataFormat format ServerDataFormat elem_format bint has_text_elements Py_UCS4 elem_delim for ti in types: oid = ti['oid'] if self.get_codec(oid, PG_FORMAT_ANY) is not None: continue name = ti['name'] schema = ti['ns'] array_element_oid = ti['elemtype'] range_subtype_oid = ti['range_subtype'] if ti['attrtypoids']: comp_type_attrs = tuple(ti['attrtypoids']) else: comp_type_attrs = None base_type = ti['basetype'] if array_element_oid: # Array type (note, there is no separate 'kind' for arrays) # Canonicalize type name to "elemtype[]" if name.startswith('_'): name = name[1:] name = '{}[]'.format(name) elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY) if elem_codec is None: elem_codec = self.declare_fallback_codec( array_element_oid, ti['elemtype_name'], schema) elem_delim = ti['elemdelim'][0] self._derived_type_codecs[oid, elem_codec.format] = \ Codec.new_array_codec( oid, name, schema, elem_codec, elem_delim) elif ti['kind'] == b'c': # Composite type if not comp_type_attrs: raise exceptions.InternalClientError( f'type record missing field types for composite {oid}') comp_elem_codecs = [] has_text_elements = False for typoid in comp_type_attrs: elem_codec = self.get_codec(typoid, PG_FORMAT_ANY) if elem_codec is None: raise exceptions.InternalClientError( f'no codec for composite attribute type {typoid}') if elem_codec.format is PG_FORMAT_TEXT: has_text_elements = True comp_elem_codecs.append(elem_codec) element_names = collections.OrderedDict() for i, attrname in enumerate(ti['attrnames']): element_names[attrname] = i # If at least one element is text-encoded, we must # encode the whole composite as text. 
if has_text_elements: elem_format = PG_FORMAT_TEXT else: elem_format = PG_FORMAT_BINARY self._derived_type_codecs[oid, elem_format] = \ Codec.new_composite_codec( oid, name, schema, elem_format, comp_elem_codecs, comp_type_attrs, element_names) elif ti['kind'] == b'd': # Domain type if not base_type: raise exceptions.InternalClientError( f'type record missing base type for domain {oid}') elem_codec = self.get_codec(base_type, PG_FORMAT_ANY) if elem_codec is None: elem_codec = self.declare_fallback_codec( base_type, ti['basetype_name'], schema) self._derived_type_codecs[oid, elem_codec.format] = elem_codec elif ti['kind'] == b'r': # Range type if not range_subtype_oid: raise exceptions.InternalClientError( f'type record missing base type for range {oid}') elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY) if elem_codec is None: elem_codec = self.declare_fallback_codec( range_subtype_oid, ti['range_subtype_name'], schema) self._derived_type_codecs[oid, elem_codec.format] = \ Codec.new_range_codec(oid, name, schema, elem_codec) elif ti['kind'] == b'm': # Multirange type if not range_subtype_oid: raise exceptions.InternalClientError( f'type record missing base type for multirange {oid}') elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY) if elem_codec is None: elem_codec = self.declare_fallback_codec( range_subtype_oid, ti['range_subtype_name'], schema) self._derived_type_codecs[oid, elem_codec.format] = \ Codec.new_multirange_codec(oid, name, schema, elem_codec) elif ti['kind'] == b'e': # Enum types are essentially text self._set_builtin_type_codec(oid, name, schema, 'scalar', TEXTOID, PG_FORMAT_ANY) else: self.declare_fallback_codec(oid, name, schema) def add_python_codec(self, typeoid, typename, typeschema, typekind, typeinfos, encoder, decoder, format, xformat): cdef: Codec core_codec = None encode_func c_encoder = NULL decode_func c_decoder = NULL Codec base_codec = None uint32_t oid = pylong_as_oid(typeoid) bint codec_set = False # Clear all 
previous overrides (this also clears type cache). self.remove_python_codec(typeoid, typename, typeschema) if typeinfos: self.add_types(typeinfos) if format == PG_FORMAT_ANY: formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY) else: formats = (format,) for fmt in formats: if xformat == PG_XFORMAT_TUPLE: if typekind == "scalar": core_codec = get_core_codec(oid, fmt, xformat) if core_codec is None: continue c_encoder = core_codec.c_encoder c_decoder = core_codec.c_decoder elif typekind == "composite": base_codec = self.get_codec(oid, fmt) if base_codec is None: continue self._custom_type_codecs[typeoid, fmt] = \ Codec.new_python_codec(oid, typename, typeschema, typekind, encoder, decoder, c_encoder, c_decoder, base_codec, fmt, xformat) codec_set = True if not codec_set: raise exceptions.InterfaceError( "{} type does not support the 'tuple' exchange format".format( typename)) def remove_python_codec(self, typeoid, typename, typeschema): for fmt in (PG_FORMAT_BINARY, PG_FORMAT_TEXT): self._custom_type_codecs.pop((typeoid, fmt), None) self.clear_type_cache() def _set_builtin_type_codec(self, typeoid, typename, typeschema, typekind, alias_to, format=PG_FORMAT_ANY): cdef: Codec codec Codec target_codec uint32_t oid = pylong_as_oid(typeoid) uint32_t alias_oid = 0 bint codec_set = False if format == PG_FORMAT_ANY: formats = (PG_FORMAT_BINARY, PG_FORMAT_TEXT) else: formats = (format,) if isinstance(alias_to, int): alias_oid = pylong_as_oid(alias_to) else: alias_oid = BUILTIN_TYPE_NAME_MAP.get(alias_to, 0) for format in formats: if alias_oid != 0: target_codec = self.get_codec(alias_oid, format) else: target_codec = get_extra_codec(alias_to, format) if target_codec is None: continue codec = target_codec.copy() codec.oid = typeoid codec.name = typename codec.schema = typeschema codec.kind = typekind self._custom_type_codecs[typeoid, format] = codec codec_set = True if not codec_set: if format == PG_FORMAT_BINARY: codec_str = 'binary' elif format == PG_FORMAT_TEXT: codec_str = 'text' 
else: codec_str = 'text or binary' raise exceptions.InterfaceError( f'cannot alias {typename} to {alias_to}: ' f'there is no {codec_str} codec for {alias_to}') def set_builtin_type_codec(self, typeoid, typename, typeschema, typekind, alias_to, format=PG_FORMAT_ANY): self._set_builtin_type_codec(typeoid, typename, typeschema, typekind, alias_to, format) self.clear_type_cache() def clear_type_cache(self): self._derived_type_codecs.clear() def declare_fallback_codec(self, uint32_t oid, str name, str schema): cdef Codec codec if oid <= MAXBUILTINOID: # This is a BKI type, for which asyncpg has no # defined codec. This should only happen for newly # added builtin types, for which this version of # asyncpg is lacking support. # raise exceptions.UnsupportedClientFeatureError( f'unhandled standard data type {name!r} (OID {oid})') else: # This is a non-BKI type, and as such, has no # stable OID, so no possibility of a builtin codec. # In this case, fallback to text format. Applications # can avoid this by specifying a codec for this type # using Connection.set_type_codec(). # self._set_builtin_type_codec(oid, name, schema, 'scalar', TEXTOID, PG_FORMAT_TEXT) codec = self.get_codec(oid, PG_FORMAT_TEXT) return codec cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format, bint ignore_custom_codec=False): cdef Codec codec if format == PG_FORMAT_ANY: codec = self.get_codec( oid, PG_FORMAT_BINARY, ignore_custom_codec) if codec is None: codec = self.get_codec( oid, PG_FORMAT_TEXT, ignore_custom_codec) return codec else: if not ignore_custom_codec: codec = self.get_custom_codec(oid, PG_FORMAT_ANY) if codec is not None: if codec.format != format: # The codec for this OID has been overridden by # set_{builtin}_type_codec with a different format. # We must respect that and not return a core codec. 
cdef inline int has_core_codec(uint32_t oid):
    """Return non-zero if a core codec is registered for ``oid``.

    Only the object-exchange slots (index ``oid``) of the binary and
    text tables are consulted.
    """
    # The tables are indexed without bounds checking; OIDs beyond the
    # supported range can never have a core codec, so bail out before
    # touching them (same guard as get_core_codec/register_core_codec).
    if oid > MAXSUPPORTEDOID:
        return 0

    return (
        (<CodecMap>codec_map).get_binary_codec_ptr(oid) != NULL
        or (<CodecMap>codec_map).get_text_codec_ptr(oid) != NULL
    )
# Register a C-level codec that is looked up by name rather than by OID.
#
# The codec is created with INVALIDOID, no schema, kind 'scalar' and the
# object exchange format, and stored in the codec map's ``extra_codecs``
# dict under the key ``(name, format)``.
cdef register_extra_codec(str name,
                          encode_func encode,
                          decode_func decode,
                          ServerDataFormat format):
    cdef:
        Codec codec
        str kind

    kind = 'scalar'

    codec = Codec(INVALIDOID)
    codec.init(name, None, kind, CODEC_C, format, PG_XFORMAT_OBJECT,
               encode, decode, None, None, None, None, None, None, None, 0)
    # NOTE(review): the bare parentheses look like the remains of a
    # type cast on codec_map that was lost in extraction -- verify
    # against the upstream file before relying on this line.
    (codec_map).extra_codecs[name, format] = codec
PG_FORMAT_BINARY, PG_XFORMAT_TUPLE) register_core_codec(TIMETZOID, pgproto.timetz_encode, pgproto.timetz_decode, PG_FORMAT_BINARY) register_core_codec(TIMETZOID, pgproto.timetz_encode_tuple, pgproto.timetz_decode_tuple, PG_FORMAT_BINARY, PG_XFORMAT_TUPLE) register_core_codec(TIMESTAMPOID, pgproto.timestamp_encode, pgproto.timestamp_decode, PG_FORMAT_BINARY) register_core_codec(TIMESTAMPOID, pgproto.timestamp_encode_tuple, pgproto.timestamp_decode_tuple, PG_FORMAT_BINARY, PG_XFORMAT_TUPLE) register_core_codec(TIMESTAMPTZOID, pgproto.timestamptz_encode, pgproto.timestamptz_decode, PG_FORMAT_BINARY) register_core_codec(TIMESTAMPTZOID, pgproto.timestamp_encode_tuple, pgproto.timestamp_decode_tuple, PG_FORMAT_BINARY, PG_XFORMAT_TUPLE) register_core_codec(INTERVALOID, pgproto.interval_encode, pgproto.interval_decode, PG_FORMAT_BINARY) register_core_codec(INTERVALOID, pgproto.interval_encode_tuple, pgproto.interval_decode_tuple, PG_FORMAT_BINARY, PG_XFORMAT_TUPLE) # For obsolete abstime/reltime/tinterval, we do not bother to # interpret the value, and simply return and pass it as text. 
# Register the core binary-format codecs for json, jsonb and jsonpath.
cdef init_json_codecs():
    # json is registered under the binary format but with the plain
    # text encoder/decoder.
    # NOTE(review): presumably because the binary representation of
    # json is its text -- confirm against the PostgreSQL documentation.
    register_core_codec(JSONOID,
                        pgproto.text_encode,
                        pgproto.text_decode,
                        PG_FORMAT_BINARY)
    register_core_codec(JSONBOID,
                        pgproto.jsonb_encode,
                        pgproto.jsonb_decode,
                        PG_FORMAT_BINARY)
    register_core_codec(JSONPATHOID,
                        pgproto.jsonpath_encode,
                        pgproto.jsonpath_decode,
                        PG_FORMAT_BINARY)
Void type is returned by SELECT void_returning_function() register_core_codec(VOIDOID, pgproto.void_encode, pgproto.void_decode, PG_FORMAT_BINARY) # Unknown type, always decoded as text register_core_codec(UNKNOWNOID, pgproto.text_encode, pgproto.text_decode, PG_FORMAT_TEXT) # OID and friends oid_types = [ OIDOID, XIDOID, CIDOID ] for oid_type in oid_types: register_core_codec(oid_type, pgproto.uint4_encode, pgproto.uint4_decode, PG_FORMAT_BINARY) # 64-bit OID types oid8_types = [ XID8OID, ] for oid_type in oid8_types: register_core_codec(oid_type, pgproto.uint8_encode, pgproto.uint8_decode, PG_FORMAT_BINARY) # reg* types -- these are really system catalog OIDs, but # allow the catalog object name as an input. We could just # decode these as OIDs, but handling them as text seems more # useful. # reg_types = [ REGPROCOID, REGPROCEDUREOID, REGOPEROID, REGOPERATOROID, REGCLASSOID, REGTYPEOID, REGCONFIGOID, REGDICTIONARYOID, REGNAMESPACEOID, REGROLEOID, REFCURSOROID, REGCOLLATIONOID, ] for reg_type in reg_types: register_core_codec(reg_type, pgproto.text_encode, pgproto.text_decode, PG_FORMAT_TEXT) # cstring type is used by Postgres' I/O functions register_core_codec(CSTRINGOID, pgproto.text_encode, pgproto.text_decode, PG_FORMAT_BINARY) # various system pseudotypes with no I/O no_io_types = [ ANYOID, TRIGGEROID, EVENT_TRIGGEROID, LANGUAGE_HANDLEROID, FDW_HANDLEROID, TSM_HANDLEROID, INTERNALOID, OPAQUEOID, ANYELEMENTOID, ANYNONARRAYOID, ANYCOMPATIBLEOID, ANYCOMPATIBLEARRAYOID, ANYCOMPATIBLENONARRAYOID, ANYCOMPATIBLERANGEOID, ANYCOMPATIBLEMULTIRANGEOID, ANYRANGEOID, ANYMULTIRANGEOID, ANYARRAYOID, PG_DDL_COMMANDOID, INDEX_AM_HANDLEROID, TABLE_AM_HANDLEROID, ] register_core_codec(ANYENUMOID, NULL, pgproto.text_decode, PG_FORMAT_TEXT) for no_io_type in no_io_types: register_core_codec(no_io_type, NULL, NULL, PG_FORMAT_BINARY) # ACL specification string register_core_codec(ACLITEMOID, pgproto.text_encode, pgproto.text_decode, PG_FORMAT_TEXT) # Postgres' serialized 
expression tree type register_core_codec(PG_NODE_TREEOID, NULL, pgproto.text_decode, PG_FORMAT_TEXT) # pg_lsn type -- a pointer to a location in the XLOG. register_core_codec(PG_LSNOID, pgproto.int8_encode, pgproto.int8_decode, PG_FORMAT_BINARY) register_core_codec(SMGROID, pgproto.text_encode, pgproto.text_decode, PG_FORMAT_TEXT) # pg_dependencies and pg_ndistinct are special types # used in pg_statistic_ext columns. register_core_codec(PG_DEPENDENCIESOID, pgproto.text_encode, pgproto.text_decode, PG_FORMAT_TEXT) register_core_codec(PG_NDISTINCTOID, pgproto.text_encode, pgproto.text_decode, PG_FORMAT_TEXT) # pg_mcv_list is a special type used in pg_statistic_ext_data # system catalog register_core_codec(PG_MCV_LISTOID, pgproto.bytea_encode, pgproto.bytea_decode, PG_FORMAT_BINARY) # These two are internal to BRIN index support and are unlikely # to be sent, but since I/O functions for these exist, add decoders # nonetheless. register_core_codec(PG_BRIN_BLOOM_SUMMARYOID, NULL, pgproto.bytea_decode, PG_FORMAT_BINARY) register_core_codec(PG_BRIN_MINMAX_MULTI_SUMMARYOID, NULL, pgproto.bytea_decode, PG_FORMAT_BINARY) cdef init_text_codecs(): textoids = [ NAMEOID, BPCHAROID, VARCHAROID, TEXTOID, XMLOID ] for oid in textoids: register_core_codec(oid, pgproto.text_encode, pgproto.text_decode, PG_FORMAT_BINARY) register_core_codec(oid, pgproto.text_encode, pgproto.text_decode, PG_FORMAT_TEXT) cdef init_tid_codecs(): register_core_codec(TIDOID, pgproto.tid_encode, pgproto.tid_decode, PG_FORMAT_BINARY) cdef init_txid_codecs(): register_core_codec(TXID_SNAPSHOTOID, pgproto.pg_snapshot_encode, pgproto.pg_snapshot_decode, PG_FORMAT_BINARY) register_core_codec(PG_SNAPSHOTOID, pgproto.pg_snapshot_encode, pgproto.pg_snapshot_decode, PG_FORMAT_BINARY) cdef init_tsearch_codecs(): ts_oids = [ TSQUERYOID, TSVECTOROID, ] for oid in ts_oids: register_core_codec(oid, pgproto.text_encode, pgproto.text_decode, PG_FORMAT_TEXT) register_core_codec(GTSVECTOROID, NULL, pgproto.text_decode, 
cdef init_monetary_codecs():
    # The money type is the only monetary type; it is exchanged as text.
    register_core_codec(MONEYOID,
                        pgproto.text_encode,
                        pgproto.text_decode,
                        PG_FORMAT_TEXT)
cdef inline _RangeArgumentType _range_type(object obj):
    # Classify a range argument: tuples and lists are treated alike,
    # Range instances get their own tag, anything else is invalid.
    cdef bint is_sequence = (
        cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj)
    )

    if is_sequence:
        return _RANGE_ARGUMENT_TUPLE

    if isinstance(obj, apg_types.Range):
        return _RANGE_ARGUMENT_RANGE

    return _RANGE_ARGUMENT_INVALID
cdef multirange_encode(ConnectionSettings settings, WriteBuffer buf,
                       object obj, uint32_t elem_oid,
                       encode_func_ex encoder, const void *encoder_arg):
    # Encode a multirange value: a sequence of range arguments.
    #
    # Writes to *buf*, in order: the datum length, the number of ranges,
    # and the ranges themselves, each encoded by range_encode().
    #
    # Raises TypeError if *obj* is not a sequence and OverflowError if
    # the element count or the encoded size does not fit the int32
    # fields of the wire format.
    cdef:
        WriteBuffer elem_data
        ssize_t elem_data_len
        ssize_t elem_count

    if not isinstance(obj, SequenceABC):
        raise TypeError(
            'expected a sequence (got type {!r})'.format(type(obj).__name__)
        )

    # Validate the element count up front, so that an oversized input
    # is rejected before any of its elements is encoded.
    elem_count = len(obj)
    if elem_count > INT32_MAX:
        raise OverflowError('too many elements in multirange value')

    # The ranges are encoded into a scratch buffer first, because the
    # total length must be written ahead of them.
    elem_data = WriteBuffer.new()

    for elem in obj:
        range_encode(settings, elem_data, elem, elem_oid, encoder, encoder_arg)

    elem_data_len = elem_data.len()
    if elem_data_len > INT32_MAX - 4:
        raise OverflowError(
            f'size of encoded multirange datum exceeds the maximum allowed'
            f' {INT32_MAX - 4} bytes')

    # Datum length (the 4 extra bytes account for the element count)
    buf.write_int32(4 + elem_data_len)
    # Number of elements in multirange
    buf.write_int32(elem_count)
    buf.write_buffer(elem_data)
cdef anonymous_record_decode(ConnectionSettings settings, FRBuffer *buf):
    # Decode an anonymous composite (record) value into a tuple.
    #
    # The data read from *buf* is an int32 attribute count followed, for
    # each attribute, by its type OID, its int32 length (-1 meaning NULL)
    # and that many bytes of data, which are decoded with the data codec
    # that *settings* provides for the attribute's type OID.
    #
    # Raises ProtocolError on a negative attribute count and
    # InternalClientError when an attribute's type has no decoder.
    cdef:
        tuple result
        ssize_t elem_count
        ssize_t i
        int32_t elem_len
        uint32_t elem_typ
        Codec elem_codec
        FRBuffer elem_buf

    elem_count = hton.unpack_int32(frb_read(buf, 4))
    if elem_count < 0:
        # Reject a malformed count explicitly rather than handing a
        # negative size to PyTuple_New().
        raise exceptions.ProtocolError(
            'unexpected record attribute count: {}'.format(elem_count))

    result = cpython.PyTuple_New(elem_count)

    for i in range(elem_count):
        elem_typ = hton.unpack_int32(frb_read(buf, 4))
        elem_len = hton.unpack_int32(frb_read(buf, 4))

        if elem_len == -1:
            # NULL attribute
            elem = None
        else:
            elem_codec = settings.get_data_codec(elem_typ)
            if elem_codec is None or not elem_codec.has_decoder():
                raise exceptions.InternalClientError(
                    'no decoder for composite type element in '
                    'position {} of type OID {}'.format(i, elem_typ))
            elem = elem_codec.decode(settings,
                                     frb_slice_from(&elem_buf, buf, elem_len))

        # PyTuple_SET_ITEM steals a reference, so take one first.
        cpython.Py_INCREF(elem)
        cpython.PyTuple_SET_ITEM(result, i, elem)

    return result
# Status codes found in the int32 field at the start of the server's
# Authentication message.
cdef enum AuthenticationMessage:
    AUTH_SUCCESSFUL = 0             # AuthenticationOk
    AUTH_REQUIRED_KERBEROS = 2
    AUTH_REQUIRED_PASSWORD = 3      # AuthenticationCleartextPassword
    AUTH_REQUIRED_PASSWORDMD5 = 5   # AuthenticationMD5Password
    AUTH_REQUIRED_SCMCRED = 6
    AUTH_REQUIRED_GSS = 7
    AUTH_REQUIRED_GSS_CONTINUE = 8  # follow-up step for both GSS and SSPI
    AUTH_REQUIRED_SSPI = 9
    AUTH_REQUIRED_SASL = 10         # AuthenticationSASL
    AUTH_SASL_CONTINUE = 11
    AUTH_SASL_FINAL = 12
# Declaration of the CoreProtocol extension type: its C-level attributes
# and the signatures of its cdef/cpdef methods.
cdef class CoreProtocol:
    cdef:
        # Buffer of inbound server messages
        ReadBuffer buffer
        bint _skip_discard
        bint _discard_data

        # executemany support data
        object _execute_iter
        str _execute_portal_name
        str _execute_stmt_name

        ConnectionStatus con_status
        ProtocolState state
        TransactionStatus xact_status

        str encoding

        object transport

        object address
        # Instance of _ConnectionParameters
        object con_params
        # Instance of SCRAMAuthentication
        SCRAMAuthentication scram
        # Instance of gssapi.SecurityContext or sspilib.SecurityContext
        object gss_ctx

        readonly int32_t backend_pid
        readonly int32_t backend_secret

        ## Result
        ResultType result_type
        object result
        bytes result_param_desc
        bytes result_row_desc
        bytes result_status_msg

        # True - completed, False - suspended
        bint result_execute_completed

    cpdef is_in_transaction(self)

    # Per-state message handlers; `mtype` is the message type byte.
    cdef _process__auth(self, char mtype)
    cdef _process__prepare(self, char mtype)
    cdef _process__bind_execute(self, char mtype)
    cdef _process__bind_execute_many(self, char mtype)
    cdef _process__close_stmt_portal(self, char mtype)
    cdef _process__simple_query(self, char mtype)
    cdef _process__bind(self, char mtype)
    cdef _process__copy_out(self, char mtype)
    cdef _process__copy_out_data(self, char mtype)
    cdef _process__copy_in(self, char mtype)
    cdef _process__copy_in_data(self, char mtype)

    # Parsers for individual server messages
    cdef _parse_msg_authentication(self)
    cdef _parse_msg_parameter_status(self)
    cdef _parse_msg_notification(self)
    cdef _parse_msg_backend_key_data(self)
    cdef _parse_msg_ready_for_query(self)
    cdef _parse_data_msgs(self)
    cdef _parse_copy_data_msgs(self)
    cdef _parse_msg_error_response(self, is_error)
    cdef _parse_msg_command_complete(self)

    # Writers for COPY-related client messages
    cdef _write_copy_data_msg(self, object data)
    cdef _write_copy_done_msg(self)
    cdef _write_copy_fail_msg(self, str cause)

    # Builders of authentication responses
    cdef _auth_password_message_cleartext(self)
    cdef _auth_password_message_md5(self, bytes salt)
    cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
    cdef _auth_password_message_sasl_continue(self, bytes server_response)
    cdef _auth_gss_init_gssapi(self)
    cdef _auth_gss_init_sspi(self, bint negotiate)
    cdef _auth_gss_get_service(self)
    cdef _auth_gss_step(self, bytes server_response)

    cdef _write(self, buf)
    cdef _writelines(self, list buffers)

    cdef _read_server_messages(self)

    cdef _push_result(self)
    cdef _reset_result(self)
    cdef _set_state(self, ProtocolState new_state)

    cdef _ensure_connected(self)

    cdef WriteBuffer _build_parse_message(self, str stmt_name, str query)
    cdef WriteBuffer _build_bind_message(self, str portal_name,
                                         str stmt_name,
                                         WriteBuffer bind_data)
    cdef WriteBuffer _build_empty_bind_data(self)
    cdef WriteBuffer _build_execute_message(self, str portal_name,
                                            int32_t limit)

    cdef _connect(self)
    cdef _prepare_and_describe(self, str stmt_name, str query)
    cdef _send_parse_message(self, str stmt_name, str query)
    cdef _send_bind_message(self, str portal_name, str stmt_name,
                            WriteBuffer bind_data, int32_t limit)
    cdef _bind_execute(self, str portal_name, str stmt_name,
                       WriteBuffer bind_data, int32_t limit)
    cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
                                 object bind_data, bint return_rows)
    cdef bint _bind_execute_many_more(self, bint first=*)
    cdef _bind_execute_many_fail(self, object error, bint first=*)
    cdef _bind(self, str portal_name, str stmt_name,
               WriteBuffer bind_data)
    cdef _execute(self, str portal_name, int32_t limit)
    cdef _close(self, str name, bint is_portal)
    cdef _simple_query(self, str query)
    cdef _copy_out(self, str copy_stmt)
    cdef _copy_in(self, str copy_stmt)
    cdef _terminate(self)

    cdef _decode_row(self, const char* buf, ssize_t buf_len)

    cdef _on_result(self)
    cdef _on_notification(self, pid, channel, payload)
    cdef _on_notice(self, parsed)
    cdef _set_server_parameter(self, name, val)
    cdef _on_connection_lost(self, exc)
def __init__(self, addr, con_params):
    # Connection target and parameters; `con_params` is an instance
    # of `_ConnectionParameters`.
    self.address = addr
    self.con_params = con_params
    self.user = con_params.user
    self.password = con_params.password

    # Inbound message buffer and initial protocol state.
    self.buffer = ReadBuffer()
    self.encoding = 'utf-8'
    self.con_status = CONNECTION_BAD
    self.state = PROTOCOL_IDLE
    self.xact_status = PQTRANS_IDLE

    # Authentication state: no response is pending yet; `scram` is a
    # `SCRAMAuthentication` once set, and `gss_ctx` is a
    # `gssapi.SecurityContext` or `sspilib.SecurityContext` once set.
    self.auth_msg = None
    self.scram = None
    self.gss_ctx = None

    self._reset_result()
# Handle one server message of type `mtype` received while the protocol
# is in the authentication state.
cdef _process__auth(self, char mtype):
    if mtype == b'R':
        # Authentication...
        try:
            self._parse_msg_authentication()
        except Exception as ex:
            # Exception in authentication parsing code
            # is usually either malformed authentication data
            # or missing support for cryptographic primitives
            # in the hashlib module.
            #
            # The failure is reported as InternalClientError with the
            # original exception attached as its cause.
            self.result_type = RESULT_FAILED
            self.result = apg_exc.InternalClientError(
                f"unexpected error while performing authentication: {ex}")
            self.result.__cause__ = ex
            self.con_status = CONNECTION_BAD
            self._push_result()
        else:
            if self.result_type != RESULT_OK:
                # The parser recorded a failure without raising.
                self.con_status = CONNECTION_BAD
                self._push_result()

            elif self.auth_msg is not None:
                # Server wants us to send auth data, so do that.
                self._write(self.auth_msg)
                self.auth_msg = None

    elif mtype == b'K':
        # BackendKeyData
        self._parse_msg_backend_key_data()

    elif mtype == b'E':
        # ErrorResponse
        self.con_status = CONNECTION_BAD
        self._parse_msg_error_response(True)
        self._push_result()

    elif mtype == b'Z':
        # ReadyForQuery
        # The connection is marked usable before the result is pushed.
        self._parse_msg_ready_for_query()
        self.con_status = CONNECTION_OK
        self._push_result()
# Handle one server message of type `mtype` received while the protocol
# is in the PROTOCOL_BIND_EXECUTE_MANY state.
cdef _process__bind_execute_many(self, char mtype):
    if mtype == b'D':
        # DataRow
        self._parse_data_msgs()

    elif mtype == b's':
        # PortalSuspended
        self.buffer.discard_message()

    elif mtype == b'C':
        # CommandComplete
        self._parse_msg_command_complete()

    elif mtype == b'E':
        # ErrorResponse
        self._parse_msg_error_response(True)

    elif mtype == b'1':
        # ParseComplete, in case `_bind_execute_many()` is reparsing
        self.buffer.discard_message()

    elif mtype == b'2':
        # BindComplete
        self.buffer.discard_message()

    elif mtype == b'Z':
        # ReadyForQuery
        self._parse_msg_ready_for_query()
        self._push_result()

    elif mtype == b'I':
        # EmptyQueryResponse
        self.buffer.discard_message()
# Handle one server message of type `mtype` received while a simple
# query is in progress.  Only the error, the command status and the
# final sync point are acted upon; everything else is dropped.
cdef _process__simple_query(self, char mtype):
    if mtype == b'E':
        # ErrorResponse
        self._parse_msg_error_response(True)

    elif mtype == b'C':
        # CommandComplete
        self._parse_msg_command_complete()

    elif mtype == b'Z':
        # ReadyForQuery
        self._parse_msg_ready_for_query()
        self._push_result()

    else:
        # 'D' - DataRow
        # 'I' - EmptyQueryResponse
        # 'T' - RowDescription
        # ...and any other message: we don't really care about
        # COPY IN etc, so the message is discarded.
        self.buffer.discard_message()
# Collect the buffered CopyData messages into `self.result` and decide
# whether to hand them to the caller right away.
cdef _parse_copy_data_msgs(self):
    cdef ReadBuffer rbuf = self.buffer

    # Consume every CopyData ('d') message present in the inbound buffer.
    self.result = rbuf.consume_messages(b'd')

    if rbuf.take_message():
        # Another (non-CopyData) message is already buffered: put it
        # back to be processed by the next protocol iteration, and
        # leave pushing the result to the _process__copy_out_data
        # subprotocol.
        rbuf.put_message()
        return

    # Nothing else is buffered: push the accumulated data out to the
    # caller in anticipation of the next CopyData batch.
    self._on_result()
    self.result = None
# Parse an Authentication message and act on its int32 status code.
#
# Depending on the status this records the outcome in `result_type` /
# `result`, and/or prepares the next client response in `auth_msg`.
# Unsupported requests and a failed SCRAM signature check are reported
# as InterfaceError through `result` rather than raised.
cdef _parse_msg_authentication(self):
    cdef:
        int32_t status
        bytes md5_salt
        list sasl_auth_methods
        list unsupported_sasl_auth_methods

    status = self.buffer.read_int32()

    if status == AUTH_SUCCESSFUL:
        # AuthenticationOk
        self.result_type = RESULT_OK

    elif status == AUTH_REQUIRED_PASSWORD:
        # AuthenticationCleartextPassword
        self.result_type = RESULT_OK
        self.auth_msg = self._auth_password_message_cleartext()

    elif status == AUTH_REQUIRED_PASSWORDMD5:
        # AuthenticationMD5Password
        # Note: MD5 salt is passed as a four-byte sequence
        md5_salt = self.buffer.read_bytes(4)
        self.auth_msg = self._auth_password_message_md5(md5_salt)

    elif status == AUTH_REQUIRED_SASL:
        # AuthenticationSASL
        # This requires making additional requests to the server in order
        # to follow the SCRAM protocol defined in RFC 5802.
        # get the SASL authentication methods that the server is providing
        sasl_auth_methods = []
        unsupported_sasl_auth_methods = []
        # determine if the advertised authentication methods are supported,
        # and if so, add them to the list
        auth_method = self.buffer.read_null_str()
        while auth_method:
            if auth_method in SCRAMAuthentication.AUTHENTICATION_METHODS:
                sasl_auth_methods.append(auth_method)
            else:
                unsupported_sasl_auth_methods.append(auth_method)
            auth_method = self.buffer.read_null_str()

        # if none of the advertised authentication methods are supported,
        # raise an error
        # otherwise, initialize the SASL authentication exchange
        if not sasl_auth_methods:
            unsupported_sasl_auth_methods = [m.decode("ascii")
                for m in unsupported_sasl_auth_methods]
            self.result_type = RESULT_FAILED
            self.result = apg_exc.InterfaceError(
                'unsupported SASL Authentication methods requested by the '
                'server: {!r}'.format(
                    ", ".join(unsupported_sasl_auth_methods)))
        else:
            self.auth_msg = self._auth_password_message_sasl_initial(
                sasl_auth_methods)

    elif status == AUTH_SASL_CONTINUE:
        # AUTH_SASL_CONTINUE
        # this requires sending the second part of the SASL exchange, where
        # the client parses information back from the server and determines
        # if this is valid.
        # The client builds a challenge response to the server
        server_response = self.buffer.consume_message()
        self.auth_msg = self._auth_password_message_sasl_continue(
            server_response)

    elif status == AUTH_SASL_FINAL:
        # AUTH_SASL_FINAL
        server_response = self.buffer.consume_message()
        if not self.scram.verify_server_final_message(server_response):
            self.result_type = RESULT_FAILED
            self.result = apg_exc.InterfaceError(
                'could not verify server signature for '
                'SCRAM authentication: scram-sha-256',
            )
        # The SCRAM exchange is over either way; drop its state.
        self.scram = None

    elif status in (AUTH_REQUIRED_GSS, AUTH_REQUIRED_SSPI):
        # AUTH_REQUIRED_SSPI is the same as AUTH_REQUIRED_GSS, except that
        # it uses protocol negotiation with SSPI clients. Both methods use
        # AUTH_REQUIRED_GSS_CONTINUE for subsequent authentication steps.
        if self.gss_ctx is not None:
            self.result_type = RESULT_FAILED
            self.result = apg_exc.InterfaceError(
                'duplicate GSSAPI/SSPI authentication request')
        else:
            if self.con_params.gsslib == 'gssapi':
                self._auth_gss_init_gssapi()
            else:
                self._auth_gss_init_sspi(status == AUTH_REQUIRED_SSPI)
            self.auth_msg = self._auth_gss_step(None)

    elif status == AUTH_REQUIRED_GSS_CONTINUE:
        server_response = self.buffer.consume_message()
        self.auth_msg = self._auth_gss_step(server_response)

    else:
        self.result_type = RESULT_FAILED
        self.result = apg_exc.InterfaceError(
            'unsupported authentication method requested by the '
            'server: {!r}'.format(AUTH_METHOD_NAME.get(status, status)))

    # The branches that call consume_message() have already taken the
    # rest of the message; all others discard what remains of it.
    if status not in (AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
                      AUTH_REQUIRED_GSS_CONTINUE):
        self.buffer.discard_message()
self.scram.parse_server_first_message(server_response) # this involves a call and response with the server msg = WriteBuffer.new_message(b'p') client_final_message = self.scram.create_client_final_message( self.password or '') msg.write_bytes(client_final_message) msg.end_message() return msg cdef _auth_gss_init_gssapi(self): try: import gssapi except ModuleNotFoundError: raise apg_exc.InterfaceError( 'gssapi module not found; please install asyncpg[gssauth] to ' 'use asyncpg with Kerberos/GSSAPI/SSPI authentication' ) from None service_name, host = self._auth_gss_get_service() self.gss_ctx = gssapi.SecurityContext( name=gssapi.Name( f'{service_name}@{host}', gssapi.NameType.hostbased_service), usage='initiate') cdef _auth_gss_init_sspi(self, bint negotiate): try: import sspilib except ModuleNotFoundError: raise apg_exc.InterfaceError( 'sspilib module not found; please install asyncpg[gssauth] to ' 'use asyncpg with Kerberos/GSSAPI/SSPI authentication' ) from None service_name, host = self._auth_gss_get_service() self.gss_ctx = sspilib.ClientSecurityContext( target_name=f'{service_name}/{host}', credential=sspilib.UserCredential( protocol='Negotiate' if negotiate else 'Kerberos')) cdef _auth_gss_get_service(self): service_name = self.con_params.krbsrvname or 'postgres' if isinstance(self.address, str): raise apg_exc.InternalClientError( 'GSSAPI/SSPI authentication is only supported for TCP/IP ' 'connections') return service_name, self.address[0] cdef _auth_gss_step(self, bytes server_response): cdef: WriteBuffer msg token = self.gss_ctx.step(server_response) if not token: self.gss_ctx = None return None msg = WriteBuffer.new_message(b'p') msg.write_bytes(token) msg.end_message() return msg cdef _parse_msg_ready_for_query(self): cdef char status = self.buffer.read_byte() if status == b'I': self.xact_status = PQTRANS_IDLE elif status == b'T': self.xact_status = PQTRANS_INTRANS elif status == b'E': self.xact_status = PQTRANS_INERROR else: self.xact_status = 
PQTRANS_UNKNOWN cdef _parse_msg_error_response(self, is_error): cdef: char code bytes message dict parsed = {} while True: code = self.buffer.read_byte() if code == 0: break message = self.buffer.read_null_str() parsed[chr(code)] = message.decode() if is_error: self.result_type = RESULT_FAILED self.result = parsed else: return parsed cdef _push_result(self): try: self._on_result() finally: self._set_state(PROTOCOL_IDLE) self._reset_result() cdef _reset_result(self): self.result_type = RESULT_OK self.result = None self.result_param_desc = None self.result_row_desc = None self.result_status_msg = None self.result_execute_completed = False self._discard_data = False # executemany support data self._execute_iter = None self._execute_portal_name = None self._execute_stmt_name = None cdef _set_state(self, ProtocolState new_state): if new_state == PROTOCOL_IDLE: if self.state == PROTOCOL_FAILED: raise apg_exc.InternalClientError( 'cannot switch to "idle" state; ' 'protocol is in the "failed" state') elif self.state == PROTOCOL_IDLE: pass else: self.state = new_state elif new_state == PROTOCOL_FAILED: self.state = PROTOCOL_FAILED elif new_state == PROTOCOL_CANCELLED: self.state = PROTOCOL_CANCELLED elif new_state == PROTOCOL_TERMINATING: self.state = PROTOCOL_TERMINATING else: if self.state == PROTOCOL_IDLE: self.state = new_state elif (self.state == PROTOCOL_COPY_OUT and new_state == PROTOCOL_COPY_OUT_DATA): self.state = new_state elif (self.state == PROTOCOL_COPY_OUT_DATA and new_state == PROTOCOL_COPY_OUT_DONE): self.state = new_state elif (self.state == PROTOCOL_COPY_IN and new_state == PROTOCOL_COPY_IN_DATA): self.state = new_state elif self.state == PROTOCOL_FAILED: raise apg_exc.InternalClientError( 'cannot switch to state {}; ' 'protocol is in the "failed" state'.format(new_state)) else: raise apg_exc.InternalClientError( 'cannot switch to state {}; ' 'another operation ({}) is in progress'.format( new_state, self.state)) cdef _ensure_connected(self): if 
self.con_status != CONNECTION_OK: raise apg_exc.InternalClientError('not connected') cdef WriteBuffer _build_parse_message(self, str stmt_name, str query): cdef WriteBuffer buf buf = WriteBuffer.new_message(b'P') buf.write_str(stmt_name, self.encoding) buf.write_str(query, self.encoding) buf.write_int16(0) buf.end_message() return buf cdef WriteBuffer _build_bind_message(self, str portal_name, str stmt_name, WriteBuffer bind_data): cdef WriteBuffer buf buf = WriteBuffer.new_message(b'B') buf.write_str(portal_name, self.encoding) buf.write_str(stmt_name, self.encoding) # Arguments buf.write_buffer(bind_data) buf.end_message() return buf cdef WriteBuffer _build_empty_bind_data(self): cdef WriteBuffer buf buf = WriteBuffer.new() buf.write_int16(0) # The number of parameter format codes buf.write_int16(0) # The number of parameter values buf.write_int16(0) # The number of result-column format codes return buf cdef WriteBuffer _build_execute_message(self, str portal_name, int32_t limit): cdef WriteBuffer buf buf = WriteBuffer.new_message(b'E') buf.write_str(portal_name, self.encoding) # name of the portal buf.write_int32(limit) # number of rows to return; 0 - all buf.end_message() return buf # API for subclasses cdef _connect(self): cdef: WriteBuffer buf WriteBuffer outbuf if self.con_status != CONNECTION_BAD: raise apg_exc.InternalClientError('already connected') self._set_state(PROTOCOL_AUTH) self.con_status = CONNECTION_STARTED # Assemble a startup message buf = WriteBuffer() # protocol version buf.write_int16(3) buf.write_int16(0) buf.write_bytestring(b'client_encoding') buf.write_bytestring("'{}'".format(self.encoding).encode('ascii')) buf.write_str('user', self.encoding) buf.write_str(self.con_params.user, self.encoding) buf.write_str('database', self.encoding) buf.write_str(self.con_params.database, self.encoding) if self.con_params.server_settings is not None: for k, v in self.con_params.server_settings.items(): buf.write_str(k, self.encoding) buf.write_str(v, 
self.encoding) buf.write_bytestring(b'') # Send the buffer outbuf = WriteBuffer() outbuf.write_int32(buf.len() + 4) outbuf.write_buffer(buf) self._write(outbuf) cdef _send_parse_message(self, str stmt_name, str query): cdef: WriteBuffer msg self._ensure_connected() msg = self._build_parse_message(stmt_name, query) self._write(msg) cdef _prepare_and_describe(self, str stmt_name, str query): cdef: WriteBuffer packet WriteBuffer buf self._ensure_connected() self._set_state(PROTOCOL_PREPARE) packet = self._build_parse_message(stmt_name, query) buf = WriteBuffer.new_message(b'D') buf.write_byte(b'S') buf.write_str(stmt_name, self.encoding) buf.end_message() packet.write_buffer(buf) packet.write_bytes(FLUSH_MESSAGE) self._write(packet) cdef _send_bind_message(self, str portal_name, str stmt_name, WriteBuffer bind_data, int32_t limit): cdef: WriteBuffer packet WriteBuffer buf buf = self._build_bind_message(portal_name, stmt_name, bind_data) packet = buf buf = self._build_execute_message(portal_name, limit) packet.write_buffer(buf) packet.write_bytes(SYNC_MESSAGE) self._write(packet) cdef _bind_execute(self, str portal_name, str stmt_name, WriteBuffer bind_data, int32_t limit): cdef WriteBuffer buf self._ensure_connected() self._set_state(PROTOCOL_BIND_EXECUTE) self.result = [] self._send_bind_message(portal_name, stmt_name, bind_data, limit) cdef bint _bind_execute_many(self, str portal_name, str stmt_name, object bind_data, bint return_rows): self._ensure_connected() self._set_state(PROTOCOL_BIND_EXECUTE_MANY) self.result = [] if return_rows else None self._discard_data = not return_rows self._execute_iter = bind_data self._execute_portal_name = portal_name self._execute_stmt_name = stmt_name return self._bind_execute_many_more(True) cdef bint _bind_execute_many_more(self, bint first=False): cdef: WriteBuffer packet WriteBuffer buf list buffers = [] # as we keep sending, the server may return an error early if self.result_type == RESULT_FAILED: self._write(SYNC_MESSAGE) 
return False # collect up to four 32KB buffers to send # https://github.com/MagicStack/asyncpg/pull/289#issuecomment-391215051 while len(buffers) < _EXECUTE_MANY_BUF_NUM: packet = WriteBuffer.new() # fill one 32KB buffer while packet.len() < _EXECUTE_MANY_BUF_SIZE: try: # grab one item from the input buf = next(self._execute_iter) # reached the end of the input except StopIteration: if first: # if we never send anything, simply set the result self._push_result() else: # otherwise, append SYNC and send the buffers packet.write_bytes(SYNC_MESSAGE) buffers.append(memoryview(packet)) self._writelines(buffers) return False # error in input, give up the buffers and cleanup except Exception as ex: self._bind_execute_many_fail(ex, first) return False # all good, write to the buffer first = False packet.write_buffer( self._build_bind_message( self._execute_portal_name, self._execute_stmt_name, buf, ) ) packet.write_buffer( self._build_execute_message(self._execute_portal_name, 0, ) ) # collected one buffer buffers.append(memoryview(packet)) # write to the wire, and signal the caller for more to send self._writelines(buffers) return True cdef _bind_execute_many_fail(self, object error, bint first=False): cdef WriteBuffer buf self.result_type = RESULT_FAILED self.result = error if first: self._push_result() elif self.is_in_transaction(): # we're in an explicit transaction, just SYNC self._write(SYNC_MESSAGE) else: # In an implicit transaction, if `ignore_till_sync` is set, # `ROLLBACK` will be ignored and `Sync` will restore the state; # or the transaction will be rolled back with a warning saying # that there was no transaction, but rollback is done anyway, # so we could safely ignore this warning. # GOTCHA: cannot use simple query message here, because it is # ignored if `ignore_till_sync` is set. 
buf = self._build_parse_message('', 'ROLLBACK') buf.write_buffer(self._build_bind_message( '', '', self._build_empty_bind_data())) buf.write_buffer(self._build_execute_message('', 0)) buf.write_bytes(SYNC_MESSAGE) self._write(buf) cdef _execute(self, str portal_name, int32_t limit): cdef WriteBuffer buf self._ensure_connected() self._set_state(PROTOCOL_EXECUTE) self.result = [] buf = self._build_execute_message(portal_name, limit) buf.write_bytes(SYNC_MESSAGE) self._write(buf) cdef _bind(self, str portal_name, str stmt_name, WriteBuffer bind_data): cdef WriteBuffer buf self._ensure_connected() self._set_state(PROTOCOL_BIND) buf = self._build_bind_message(portal_name, stmt_name, bind_data) buf.write_bytes(SYNC_MESSAGE) self._write(buf) cdef _close(self, str name, bint is_portal): cdef WriteBuffer buf self._ensure_connected() self._set_state(PROTOCOL_CLOSE_STMT_PORTAL) buf = WriteBuffer.new_message(b'C') if is_portal: buf.write_byte(b'P') else: buf.write_byte(b'S') buf.write_str(name, self.encoding) buf.end_message() buf.write_bytes(SYNC_MESSAGE) self._write(buf) cdef _simple_query(self, str query): cdef WriteBuffer buf self._ensure_connected() self._set_state(PROTOCOL_SIMPLE_QUERY) buf = WriteBuffer.new_message(b'Q') buf.write_str(query, self.encoding) buf.end_message() self._write(buf) cdef _copy_out(self, str copy_stmt): cdef WriteBuffer buf self._ensure_connected() self._set_state(PROTOCOL_COPY_OUT) # Send the COPY .. TO STDOUT using the SimpleQuery protocol. 
buf = WriteBuffer.new_message(b'Q') buf.write_str(copy_stmt, self.encoding) buf.end_message() self._write(buf) cdef _copy_in(self, str copy_stmt): cdef WriteBuffer buf self._ensure_connected() self._set_state(PROTOCOL_COPY_IN) buf = WriteBuffer.new_message(b'Q') buf.write_str(copy_stmt, self.encoding) buf.end_message() self._write(buf) cdef _terminate(self): cdef WriteBuffer buf self._ensure_connected() self._set_state(PROTOCOL_TERMINATING) buf = WriteBuffer.new_message(b'X') buf.end_message() self._write(buf) cdef _write(self, buf): raise NotImplementedError cdef _writelines(self, list buffers): raise NotImplementedError cdef _decode_row(self, const char* buf, ssize_t buf_len): pass cdef _set_server_parameter(self, name, val): pass cdef _on_result(self): pass cdef _on_notice(self, parsed): pass cdef _on_notification(self, pid, channel, payload): pass cdef _on_connection_lost(self, exc): pass SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message()) FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message()) ================================================ FILE: asyncpg/protocol/cpythonx.pxd ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 cdef extern from "Python.h": int PyByteArray_Check(object) int PyMemoryView_Check(object) Py_buffer *PyMemoryView_GET_BUFFER(object) object PyMemoryView_GetContiguous(object, int buffertype, char order) Py_UCS4* PyUnicode_AsUCS4Copy(object) except NULL object PyUnicode_FromKindAndData( int kind, const void *buffer, Py_ssize_t size) int PyUnicode_4BYTE_KIND ================================================ FILE: asyncpg/protocol/encodings.pyx ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 
2.0 License: http://www.apache.org/licenses/LICENSE-2.0


'''Map PostgreSQL encoding names to Python encoding names

https://www.postgresql.org/docs/current/static/multibyte.html#CHARSET-TABLE
'''

# Keys are lower-cased PostgreSQL encoding names or aliases; values are the
# names of the corresponding Python codecs.  Names that are absent from the
# map are passed through (lower-cased) by get_python_encoding().
ENCODINGS_MAP = {
    'abc': 'cp1258',
    'alt': 'cp866',
    'euc_cn': 'euccn',
    'euc_jp': 'eucjp',
    'euc_kr': 'euckr',
    'koi8r': 'koi8_r',
    'koi8u': 'koi8_u',
    # Shift_JIS-2004 and EUC-JIS-2004 are different byte encodings of the
    # same character set, so this must not be mapped to 'euc_jis_2004'.
    'shift_jis_2004': 'shift_jis_2004',
    'sjis': 'shift_jis',
    'sql_ascii': 'ascii',
    'vscii': 'cp1258',
    'tcvn': 'cp1258',
    'tcvn5712': 'cp1258',
    'unicode': 'utf_8',
    # "WIN" is PostgreSQL's alias for WIN1251 (Windows CP1251).
    'win': 'cp1251',
    'win1250': 'cp1250',
    'win1251': 'cp1251',
    'win1252': 'cp1252',
    'win1253': 'cp1253',
    'win1254': 'cp1254',
    'win1255': 'cp1255',
    'win1256': 'cp1256',
    'win1257': 'cp1257',
    'win1258': 'cp1258',
    'win866': 'cp866',
    'win874': 'cp874',
    'win932': 'cp932',
    'win936': 'cp936',
    'win949': 'cp949',
    'win950': 'cp950',
    'windows1250': 'cp1250',
    'windows1251': 'cp1251',
    'windows1252': 'cp1252',
    'windows1253': 'cp1253',
    'windows1254': 'cp1254',
    'windows1255': 'cp1255',
    'windows1256': 'cp1256',
    'windows1257': 'cp1257',
    'windows1258': 'cp1258',
    'windows866': 'cp866',
    'windows874': 'cp874',
    'windows932': 'cp932',
    'windows936': 'cp936',
    'windows949': 'cp949',
    'windows950': 'cp950',
}


cdef get_python_encoding(pg_encoding):
    # Return the Python codec name for a PostgreSQL encoding name.
    # The lookup is case-insensitive; unknown names are returned
    # lower-cased and otherwise unchanged.
    name = pg_encoding.lower()
    return ENCODINGS_MAP.get(name, name)

================================================
FILE: asyncpg/protocol/pgtypes.pxi
================================================
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


# GENERATED FROM pg_catalog.pg_type
# DO NOT MODIFY, use tools/generate_type_map.py to update

DEF INVALIDOID = 0
DEF MAXBUILTINOID = 9999
DEF MAXSUPPORTEDOID = 5080

DEF BOOLOID = 16
DEF BYTEAOID = 17
DEF CHAROID = 18
DEF NAMEOID = 19
DEF INT8OID = 20
DEF INT2OID = 21
DEF INT4OID = 23
DEF REGPROCOID = 24
DEF TEXTOID = 25
DEF OIDOID = 26
DEF TIDOID = 27
DEF XIDOID = 28
DEF CIDOID = 29
DEF PG_DDL_COMMANDOID = 32 DEF JSONOID = 114 DEF XMLOID = 142 DEF PG_NODE_TREEOID = 194 DEF SMGROID = 210 DEF TABLE_AM_HANDLEROID = 269 DEF INDEX_AM_HANDLEROID = 325 DEF POINTOID = 600 DEF LSEGOID = 601 DEF PATHOID = 602 DEF BOXOID = 603 DEF POLYGONOID = 604 DEF LINEOID = 628 DEF CIDROID = 650 DEF FLOAT4OID = 700 DEF FLOAT8OID = 701 DEF ABSTIMEOID = 702 DEF RELTIMEOID = 703 DEF TINTERVALOID = 704 DEF UNKNOWNOID = 705 DEF CIRCLEOID = 718 DEF MACADDR8OID = 774 DEF MONEYOID = 790 DEF MACADDROID = 829 DEF INETOID = 869 DEF _TEXTOID = 1009 DEF _OIDOID = 1028 DEF ACLITEMOID = 1033 DEF BPCHAROID = 1042 DEF VARCHAROID = 1043 DEF DATEOID = 1082 DEF TIMEOID = 1083 DEF TIMESTAMPOID = 1114 DEF TIMESTAMPTZOID = 1184 DEF INTERVALOID = 1186 DEF TIMETZOID = 1266 DEF BITOID = 1560 DEF VARBITOID = 1562 DEF NUMERICOID = 1700 DEF REFCURSOROID = 1790 DEF REGPROCEDUREOID = 2202 DEF REGOPEROID = 2203 DEF REGOPERATOROID = 2204 DEF REGCLASSOID = 2205 DEF REGTYPEOID = 2206 DEF RECORDOID = 2249 DEF CSTRINGOID = 2275 DEF ANYOID = 2276 DEF ANYARRAYOID = 2277 DEF VOIDOID = 2278 DEF TRIGGEROID = 2279 DEF LANGUAGE_HANDLEROID = 2280 DEF INTERNALOID = 2281 DEF OPAQUEOID = 2282 DEF ANYELEMENTOID = 2283 DEF ANYNONARRAYOID = 2776 DEF UUIDOID = 2950 DEF TXID_SNAPSHOTOID = 2970 DEF FDW_HANDLEROID = 3115 DEF PG_LSNOID = 3220 DEF TSM_HANDLEROID = 3310 DEF PG_NDISTINCTOID = 3361 DEF PG_DEPENDENCIESOID = 3402 DEF ANYENUMOID = 3500 DEF TSVECTOROID = 3614 DEF TSQUERYOID = 3615 DEF GTSVECTOROID = 3642 DEF REGCONFIGOID = 3734 DEF REGDICTIONARYOID = 3769 DEF JSONBOID = 3802 DEF ANYRANGEOID = 3831 DEF EVENT_TRIGGEROID = 3838 DEF JSONPATHOID = 4072 DEF REGNAMESPACEOID = 4089 DEF REGROLEOID = 4096 DEF REGCOLLATIONOID = 4191 DEF ANYMULTIRANGEOID = 4537 DEF ANYCOMPATIBLEMULTIRANGEOID = 4538 DEF PG_BRIN_BLOOM_SUMMARYOID = 4600 DEF PG_BRIN_MINMAX_MULTI_SUMMARYOID = 4601 DEF PG_MCV_LISTOID = 5017 DEF PG_SNAPSHOTOID = 5038 DEF XID8OID = 5069 DEF ANYCOMPATIBLEOID = 5077 DEF ANYCOMPATIBLEARRAYOID = 5078 DEF 
ANYCOMPATIBLENONARRAYOID = 5079 DEF ANYCOMPATIBLERANGEOID = 5080 ARRAY_TYPES = {_TEXTOID, _OIDOID} BUILTIN_TYPE_OID_MAP = { ABSTIMEOID: 'abstime', ACLITEMOID: 'aclitem', ANYARRAYOID: 'anyarray', ANYCOMPATIBLEARRAYOID: 'anycompatiblearray', ANYCOMPATIBLEMULTIRANGEOID: 'anycompatiblemultirange', ANYCOMPATIBLENONARRAYOID: 'anycompatiblenonarray', ANYCOMPATIBLEOID: 'anycompatible', ANYCOMPATIBLERANGEOID: 'anycompatiblerange', ANYELEMENTOID: 'anyelement', ANYENUMOID: 'anyenum', ANYMULTIRANGEOID: 'anymultirange', ANYNONARRAYOID: 'anynonarray', ANYOID: 'any', ANYRANGEOID: 'anyrange', BITOID: 'bit', BOOLOID: 'bool', BOXOID: 'box', BPCHAROID: 'bpchar', BYTEAOID: 'bytea', CHAROID: 'char', CIDOID: 'cid', CIDROID: 'cidr', CIRCLEOID: 'circle', CSTRINGOID: 'cstring', DATEOID: 'date', EVENT_TRIGGEROID: 'event_trigger', FDW_HANDLEROID: 'fdw_handler', FLOAT4OID: 'float4', FLOAT8OID: 'float8', GTSVECTOROID: 'gtsvector', INDEX_AM_HANDLEROID: 'index_am_handler', INETOID: 'inet', INT2OID: 'int2', INT4OID: 'int4', INT8OID: 'int8', INTERNALOID: 'internal', INTERVALOID: 'interval', JSONBOID: 'jsonb', JSONOID: 'json', JSONPATHOID: 'jsonpath', LANGUAGE_HANDLEROID: 'language_handler', LINEOID: 'line', LSEGOID: 'lseg', MACADDR8OID: 'macaddr8', MACADDROID: 'macaddr', MONEYOID: 'money', NAMEOID: 'name', NUMERICOID: 'numeric', OIDOID: 'oid', OPAQUEOID: 'opaque', PATHOID: 'path', PG_BRIN_BLOOM_SUMMARYOID: 'pg_brin_bloom_summary', PG_BRIN_MINMAX_MULTI_SUMMARYOID: 'pg_brin_minmax_multi_summary', PG_DDL_COMMANDOID: 'pg_ddl_command', PG_DEPENDENCIESOID: 'pg_dependencies', PG_LSNOID: 'pg_lsn', PG_MCV_LISTOID: 'pg_mcv_list', PG_NDISTINCTOID: 'pg_ndistinct', PG_NODE_TREEOID: 'pg_node_tree', PG_SNAPSHOTOID: 'pg_snapshot', POINTOID: 'point', POLYGONOID: 'polygon', RECORDOID: 'record', REFCURSOROID: 'refcursor', REGCLASSOID: 'regclass', REGCOLLATIONOID: 'regcollation', REGCONFIGOID: 'regconfig', REGDICTIONARYOID: 'regdictionary', REGNAMESPACEOID: 'regnamespace', REGOPERATOROID: 'regoperator', REGOPEROID: 
'regoper', REGPROCEDUREOID: 'regprocedure', REGPROCOID: 'regproc', REGROLEOID: 'regrole', REGTYPEOID: 'regtype', RELTIMEOID: 'reltime', SMGROID: 'smgr', TABLE_AM_HANDLEROID: 'table_am_handler', TEXTOID: 'text', TIDOID: 'tid', TIMEOID: 'time', TIMESTAMPOID: 'timestamp', TIMESTAMPTZOID: 'timestamptz', TIMETZOID: 'timetz', TINTERVALOID: 'tinterval', TRIGGEROID: 'trigger', TSM_HANDLEROID: 'tsm_handler', TSQUERYOID: 'tsquery', TSVECTOROID: 'tsvector', TXID_SNAPSHOTOID: 'txid_snapshot', UNKNOWNOID: 'unknown', UUIDOID: 'uuid', VARBITOID: 'varbit', VARCHAROID: 'varchar', VOIDOID: 'void', XID8OID: 'xid8', XIDOID: 'xid', XMLOID: 'xml', _OIDOID: 'oid[]', _TEXTOID: 'text[]' } BUILTIN_TYPE_NAME_MAP = {v: k for k, v in BUILTIN_TYPE_OID_MAP.items()} BUILTIN_TYPE_NAME_MAP['smallint'] = \ BUILTIN_TYPE_NAME_MAP['int2'] BUILTIN_TYPE_NAME_MAP['int'] = \ BUILTIN_TYPE_NAME_MAP['int4'] BUILTIN_TYPE_NAME_MAP['integer'] = \ BUILTIN_TYPE_NAME_MAP['int4'] BUILTIN_TYPE_NAME_MAP['bigint'] = \ BUILTIN_TYPE_NAME_MAP['int8'] BUILTIN_TYPE_NAME_MAP['decimal'] = \ BUILTIN_TYPE_NAME_MAP['numeric'] BUILTIN_TYPE_NAME_MAP['real'] = \ BUILTIN_TYPE_NAME_MAP['float4'] BUILTIN_TYPE_NAME_MAP['double precision'] = \ BUILTIN_TYPE_NAME_MAP['float8'] BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \ BUILTIN_TYPE_NAME_MAP['timestamptz'] BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \ BUILTIN_TYPE_NAME_MAP['timestamp'] BUILTIN_TYPE_NAME_MAP['time with timezone'] = \ BUILTIN_TYPE_NAME_MAP['timetz'] BUILTIN_TYPE_NAME_MAP['time without timezone'] = \ BUILTIN_TYPE_NAME_MAP['time'] BUILTIN_TYPE_NAME_MAP['char'] = \ BUILTIN_TYPE_NAME_MAP['bpchar'] BUILTIN_TYPE_NAME_MAP['character'] = \ BUILTIN_TYPE_NAME_MAP['bpchar'] BUILTIN_TYPE_NAME_MAP['character varying'] = \ BUILTIN_TYPE_NAME_MAP['varchar'] BUILTIN_TYPE_NAME_MAP['bit varying'] = \ BUILTIN_TYPE_NAME_MAP['varbit'] ================================================ FILE: asyncpg/protocol/prepared_stmt.pxd ================================================ # 
Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 cdef class PreparedStatementState: cdef: readonly str name readonly str query readonly bint closed readonly bint prepared readonly int refs readonly type record_class readonly bint ignore_custom_codec list row_desc list parameters_desc ConnectionSettings settings int16_t args_num bint have_text_args tuple args_codecs int16_t cols_num object cols_desc bint have_text_cols tuple rows_codecs cdef _encode_bind_msg(self, args, int seqno = ?) cpdef _init_codecs(self) cdef _ensure_rows_decoder(self) cdef _ensure_args_encoder(self) cdef _set_row_desc(self, object desc) cdef _set_args_desc(self, object desc) cdef _decode_row(self, const char* cbuf, ssize_t buf_len) ================================================ FILE: asyncpg/protocol/prepared_stmt.pyx ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 from asyncpg import exceptions @cython.final cdef class PreparedStatementState: def __cinit__( self, str name, str query, BaseProtocol protocol, type record_class, bint ignore_custom_codec ): self.name = name self.query = query self.settings = protocol.settings self.row_desc = self.parameters_desc = None self.args_codecs = self.rows_codecs = None self.args_num = self.cols_num = 0 self.cols_desc = None self.closed = False self.prepared = True self.refs = 0 self.record_class = record_class self.ignore_custom_codec = ignore_custom_codec def _get_parameters(self): cdef Codec codec result = [] for oid in self.parameters_desc: codec = self.settings.get_data_codec(oid) if codec is None: raise exceptions.InternalClientError( 'missing codec information for OID {}'.format(oid)) result.append(apg_types.Type( 
oid, codec.name, codec.kind, codec.schema)) return tuple(result) def _get_attributes(self): cdef Codec codec if not self.row_desc: return () result = [] for d in self.row_desc: name = d[0] oid = d[3] codec = self.settings.get_data_codec(oid) if codec is None: raise exceptions.InternalClientError( 'missing codec information for OID {}'.format(oid)) name = name.decode(self.settings._encoding) result.append( apg_types.Attribute(name, apg_types.Type(oid, codec.name, codec.kind, codec.schema))) return tuple(result) def _init_types(self): cdef: Codec codec set missing = set() if self.parameters_desc: for p_oid in self.parameters_desc: codec = self.settings.get_data_codec(p_oid) if codec is None or not codec.has_encoder(): missing.add(p_oid) if self.row_desc: for rdesc in self.row_desc: codec = self.settings.get_data_codec((rdesc[3])) if codec is None or not codec.has_decoder(): missing.add(rdesc[3]) return missing cpdef _init_codecs(self): self._ensure_args_encoder() self._ensure_rows_decoder() def attach(self): self.refs += 1 def detach(self): self.refs -= 1 def mark_closed(self): self.closed = True def mark_unprepared(self): if self.name: raise exceptions.InternalClientError( "named prepared statements cannot be marked unprepared") self.prepared = False cdef _encode_bind_msg(self, args, int seqno = -1): cdef: int idx WriteBuffer writer Codec codec if not cpython.PySequence_Check(args): if seqno >= 0: raise exceptions.DataError( f'invalid input in executemany() argument sequence ' f'element #{seqno}: expected a sequence, got ' f'{type(args).__name__}' ) else: # Non executemany() callers do not pass user input directly, # so bad input is a bug. 
raise exceptions.InternalClientError( f'Bind: expected a sequence, got {type(args).__name__}') if len(args) > 32767: raise exceptions.InterfaceError( 'the number of query arguments cannot exceed 32767') writer = WriteBuffer.new() num_args_passed = len(args) if self.args_num != num_args_passed: hint = 'Check the query against the passed list of arguments.' if self.args_num == 0: # If the server was expecting zero arguments, it is likely # that the user tried to parametrize a statement that does # not support parameters. hint += (r' Note that parameters are supported only in' r' SELECT, INSERT, UPDATE, DELETE, MERGE and VALUES' r' statements, and will *not* work in statements ' r' like CREATE VIEW or DECLARE CURSOR.') raise exceptions.InterfaceError( 'the server expects {x} argument{s} for this query, ' '{y} {w} passed'.format( x=self.args_num, s='s' if self.args_num != 1 else '', y=num_args_passed, w='was' if num_args_passed == 1 else 'were'), hint=hint) if self.have_text_args: writer.write_int16(self.args_num) for idx in range(self.args_num): codec = (self.args_codecs[idx]) writer.write_int16(codec.format) else: # All arguments are in binary format writer.write_int32(0x00010001) writer.write_int16(self.args_num) for idx in range(self.args_num): arg = args[idx] if arg is None: writer.write_int32(-1) else: codec = (self.args_codecs[idx]) try: codec.encode(self.settings, writer, arg) except (AssertionError, exceptions.InternalClientError): # These are internal errors and should raise as-is. raise except exceptions.InterfaceError as e: # This is already a descriptive error, but annotate # with argument name for clarity. pos = f'${idx + 1}' if seqno >= 0: pos = ( f'{pos} in element #{seqno} of' f' executemany() sequence' ) raise e.with_msg( f'query argument {pos}: {e.args[0]}' ) from None except Exception as e: # Everything else is assumed to be an encoding error # due to invalid input. 
pos = f'${idx + 1}'
                    if seqno >= 0:
                        pos = (
                            f'{pos} in element #{seqno} of'
                            f' executemany() sequence'
                        )
                    value_repr = repr(arg)
                    if len(value_repr) > 40:
                        value_repr = value_repr[:40] + '...'

                    raise exceptions.DataError(
                        f'invalid input for query argument'
                        f' {pos}: {value_repr} ({e})'
                    ) from e

        if self.have_text_cols:
            writer.write_int16(self.cols_num)
            for idx in range(self.cols_num):
                codec = (self.rows_codecs[idx])
                writer.write_int16(codec.format)
        else:
            # All columns are in binary format
            writer.write_int32(0x00010001)

        return writer

    cdef _ensure_rows_decoder(self):
        # Build, once, the record descriptor (``cols_desc``) and the tuple
        # of per-column codecs (``rows_codecs``) from ``row_desc``.
        # ``have_text_cols`` is set when at least one column codec is not
        # binary.  Raises InternalClientError if a column has no decoder.
        cdef:
            list cols_names
            object cols_mapping
            tuple row
            uint32_t oid
            Codec codec
            list codecs

        # Already initialized.
        if self.cols_desc is not None:
            return

        if self.cols_num == 0:
            self.cols_desc = RecordDescriptor({}, ())
            return

        cols_mapping = collections.OrderedDict()
        cols_names = []
        codecs = []
        for i from 0 <= i < self.cols_num:
            row = self.row_desc[i]
            col_name = row[0].decode(self.settings._encoding)
            cols_mapping[col_name] = i
            cols_names.append(col_name)
            oid = row[3]
            codec = self.settings.get_data_codec(
                oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_decoder():
                raise exceptions.InternalClientError(
                    'no decoder for OID {}'.format(oid))
            if not codec.is_binary():
                self.have_text_cols = True

            codecs.append(codec)

        self.cols_desc = RecordDescriptor(
            cols_mapping, tuple(cols_names))

        self.rows_codecs = tuple(codecs)

    cdef _ensure_args_encoder(self):
        # Build, once, the tuple of per-argument codecs (``args_codecs``)
        # from ``parameters_desc``.  ``have_text_args`` is set when at
        # least one argument codec is not binary.  Raises
        # InternalClientError if an argument type has no encoder.
        cdef:
            uint32_t p_oid
            Codec codec
            list codecs = []

        # Nothing to do without arguments or when already initialized.
        if self.args_num == 0 or self.args_codecs is not None:
            return

        for i from 0 <= i < self.args_num:
            p_oid = self.parameters_desc[i]
            codec = self.settings.get_data_codec(
                p_oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_encoder():
                raise exceptions.InternalClientError(
                    'no encoder for OID {}'.format(p_oid))
            # Only flag text arguments when a codec really is non-binary,
            # mirroring _ensure_rows_decoder().  The previous test,
            # ``codec.type not in {}``, checked membership in an empty
            # dict and was therefore always true, which made the compact
            # "all arguments are binary" Bind header unreachable.
            if not codec.is_binary():
                self.have_text_args = True

            codecs.append(codec)

        self.args_codecs = tuple(codecs)

    cdef _set_row_desc(self, object desc):
        self.row_desc =
_decode_row_desc(desc) self.cols_num = (len(self.row_desc)) cdef _set_args_desc(self, object desc): self.parameters_desc = _decode_parameters_desc(desc) self.args_num = (len(self.parameters_desc)) cdef _decode_row(self, const char* cbuf, ssize_t buf_len): cdef: Codec codec int16_t fnum int32_t flen object dec_row tuple rows_codecs = self.rows_codecs ConnectionSettings settings = self.settings int32_t i FRBuffer rbuf ssize_t bl frb_init(&rbuf, cbuf, buf_len) fnum = hton.unpack_int16(frb_read(&rbuf, 2)) if fnum != self.cols_num: raise exceptions.ProtocolError( 'the number of columns in the result row ({}) is ' 'different from what was described ({})'.format( fnum, self.cols_num)) dec_row = self.cols_desc.make_record(self.record_class, fnum) for i in range(fnum): flen = hton.unpack_int32(frb_read(&rbuf, 4)) if flen == -1: val = None else: # Clamp buffer size to that of the reported field length # to make sure that codecs can rely on read_all() working # properly. bl = frb_get_len(&rbuf) if flen > bl: frb_check(&rbuf, flen) frb_set_len(&rbuf, flen) codec = cpython.PyTuple_GET_ITEM(rows_codecs, i) val = codec.decode(settings, &rbuf) if frb_get_len(&rbuf) != 0: raise BufferError( 'unexpected trailing {} bytes in buffer'.format( frb_get_len(&rbuf))) frb_set_len(&rbuf, bl - flen) cpython.Py_INCREF(val) recordcapi.ApgRecord_SET_ITEM(dec_row, i, val) if frb_get_len(&rbuf) != 0: raise BufferError('unexpected trailing {} bytes in buffer'.format( frb_get_len(&rbuf))) return dec_row cdef _decode_parameters_desc(object desc): cdef: ReadBuffer reader int16_t nparams uint32_t p_oid list result = [] reader = ReadBuffer.new_message_parser(desc) nparams = reader.read_int16() for i from 0 <= i < nparams: p_oid = reader.read_int32() result.append(p_oid) return result cdef _decode_row_desc(object desc): cdef: ReadBuffer reader int16_t nfields bytes f_name uint32_t f_table_oid int16_t f_column_num uint32_t f_dt_oid int16_t f_dt_size int32_t f_dt_mod int16_t f_format list result reader = 
ReadBuffer.new_message_parser(desc) nfields = reader.read_int16() result = [] for i from 0 <= i < nfields: f_name = reader.read_null_str() f_table_oid = reader.read_int32() f_column_num = reader.read_int16() f_dt_oid = reader.read_int32() f_dt_size = reader.read_int16() f_dt_mod = reader.read_int32() f_format = reader.read_int16() result.append( (f_name, f_table_oid, f_column_num, f_dt_oid, f_dt_size, f_dt_mod, f_format)) return result ================================================ FILE: asyncpg/protocol/protocol.pxd ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 from libc.stdint cimport int16_t, int32_t, uint16_t, \ uint32_t, int64_t, uint64_t from asyncpg.pgproto.debug cimport PG_DEBUG from asyncpg.pgproto.pgproto cimport ( WriteBuffer, ReadBuffer, FRBuffer, ) from asyncpg.pgproto cimport pgproto include "consts.pxi" include "pgtypes.pxi" include "codecs/base.pxd" include "settings.pxd" include "coreproto.pxd" include "prepared_stmt.pxd" cdef class BaseProtocol(CoreProtocol): cdef: object loop ConnectionSettings settings object cancel_sent_waiter object cancel_waiter object waiter bint return_extra object create_future object timeout_handle object conref type record_class bint is_reading str last_query bint writing_paused bint closing readonly uint64_t queries_count bint _is_ssl PreparedStatementState statement cdef get_connection(self) cdef _get_timeout_impl(self, timeout) cdef _check_state(self) cdef _new_waiter(self, timeout) cdef _coreproto_error(self) cdef _on_result__connect(self, object waiter) cdef _on_result__prepare(self, object waiter) cdef _on_result__bind_and_exec(self, object waiter) cdef _on_result__close_stmt_or_portal(self, object waiter) cdef _on_result__simple_query(self, object waiter) cdef _on_result__bind(self, object waiter) cdef 
_on_result__copy_out(self, object waiter) cdef _on_result__copy_in(self, object waiter) cdef _handle_waiter_on_connection_lost(self, cause) cdef _dispatch_result(self) cdef inline resume_reading(self) cdef inline pause_reading(self) ================================================ FILE: asyncpg/protocol/protocol.pyi ================================================ import asyncio import asyncio.protocols import hmac from codecs import CodecInfo from collections.abc import Callable, Iterable, Sequence from hashlib import md5, sha256 from typing import ( Any, ClassVar, Final, Generic, Literal, NewType, TypeVar, final, overload, ) from typing_extensions import TypeAlias import asyncpg.pgproto.pgproto from ..connect_utils import _ConnectionParameters from ..pgproto.pgproto import WriteBuffer from ..types import Attribute, Type from .record import Record _Record = TypeVar('_Record', bound=Record) _OtherRecord = TypeVar('_OtherRecord', bound=Record) _PreparedStatementState = TypeVar( '_PreparedStatementState', bound=PreparedStatementState[Any] ) _NoTimeoutType = NewType('_NoTimeoutType', object) _TimeoutType: TypeAlias = float | None | _NoTimeoutType BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]] BUILTIN_TYPE_OID_MAP: Final[dict[int, str]] NO_TIMEOUT: Final[_NoTimeoutType] hashlib_md5 = md5 @final class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext): __pyx_vtable__: Any def __init__(self, conn_key: object) -> None: ... def add_python_codec( self, typeoid: int, typename: str, typeschema: str, typeinfos: Iterable[object], typekind: str, encoder: Callable[[Any], Any], decoder: Callable[[Any], Any], format: object, ) -> Any: ... def clear_type_cache(self) -> None: ... def get_data_codec( self, oid: int, format: object = ..., ignore_custom_codec: bool = ... ) -> Any: ... def get_text_codec(self) -> CodecInfo: ... def register_data_types(self, types: Iterable[object]) -> None: ... def remove_python_codec( self, typeoid: int, typename: str, typeschema: str ) -> None: ... 
def set_builtin_type_codec( self, typeoid: int, typename: str, typeschema: str, typekind: str, alias_to: str, format: object = ..., ) -> Any: ... def __getattr__(self, name: str) -> Any: ... def __reduce__(self) -> Any: ... @final class PreparedStatementState(Generic[_Record]): closed: bool prepared: bool name: str query: str refs: int record_class: type[_Record] ignore_custom_codec: bool __pyx_vtable__: Any def __init__( self, name: str, query: str, protocol: BaseProtocol[Any], record_class: type[_Record], ignore_custom_codec: bool, ) -> None: ... def _get_parameters(self) -> tuple[Type, ...]: ... def _get_attributes(self) -> tuple[Attribute, ...]: ... def _init_types(self) -> set[int]: ... def _init_codecs(self) -> None: ... def attach(self) -> None: ... def detach(self) -> None: ... def mark_closed(self) -> None: ... def mark_unprepared(self) -> None: ... def __reduce__(self) -> Any: ... class CoreProtocol: backend_pid: Any backend_secret: Any __pyx_vtable__: Any def __init__(self, addr: object, con_params: _ConnectionParameters) -> None: ... def is_in_transaction(self) -> bool: ... def __reduce__(self) -> Any: ... class BaseProtocol(CoreProtocol, Generic[_Record]): queries_count: Any is_ssl: bool __pyx_vtable__: Any def __init__( self, addr: object, connected_fut: object, con_params: _ConnectionParameters, record_class: type[_Record], loop: object, ) -> None: ... def set_connection(self, connection: object) -> None: ... def get_server_pid(self, *args: object, **kwargs: object) -> int: ... def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ... def get_record_class(self) -> type[_Record]: ... def abort(self) -> None: ... async def bind( self, state: PreparedStatementState[_OtherRecord], args: Sequence[object], portal_name: str, timeout: _TimeoutType, ) -> Any: ... 
@overload async def bind_execute( self, state: PreparedStatementState[_OtherRecord], args: Sequence[object], portal_name: str, limit: int, return_extra: Literal[False], timeout: _TimeoutType, ) -> list[_OtherRecord]: ... @overload async def bind_execute( self, state: PreparedStatementState[_OtherRecord], args: Sequence[object], portal_name: str, limit: int, return_extra: Literal[True], timeout: _TimeoutType, ) -> tuple[list[_OtherRecord], bytes, bool]: ... @overload async def bind_execute( self, state: PreparedStatementState[_OtherRecord], args: Sequence[object], portal_name: str, limit: int, return_extra: bool, timeout: _TimeoutType, ) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ... async def bind_execute_many( self, state: PreparedStatementState[_OtherRecord], args: Iterable[Sequence[object]], portal_name: str, timeout: _TimeoutType, ) -> None: ... async def close(self, timeout: _TimeoutType) -> None: ... def _get_timeout(self, timeout: _TimeoutType) -> float | None: ... def _is_cancelling(self) -> bool: ... async def _wait_for_cancellation(self) -> None: ... async def close_statement( self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType ) -> Any: ... async def copy_in(self, *args: object, **kwargs: object) -> str: ... async def copy_out(self, *args: object, **kwargs: object) -> str: ... async def execute(self, *args: object, **kwargs: object) -> Any: ... def is_closed(self, *args: object, **kwargs: object) -> Any: ... def is_connected(self, *args: object, **kwargs: object) -> Any: ... def data_received(self, data: object) -> None: ... def connection_made(self, transport: object) -> None: ... def connection_lost(self, exc: Exception | None) -> None: ... def pause_writing(self, *args: object, **kwargs: object) -> Any: ... 
# Stub methods of ``class BaseProtocol`` (continued), followed by the
# remaining module-level stub classes.
    # prepare() either fills in and returns the ``state`` it is given, or
    # creates a new statement state bound to ``record_class``.
    # NOTE(review): the implementation in protocol.pyx declares ``timeout``
    # without a default and routes it through _get_timeout_impl, which also
    # handles NO_TIMEOUT -- confirm whether ``float | None = ...`` is the
    # intended annotation here.
    @overload
    async def prepare(
        self,
        stmt_name: str,
        query: str,
        timeout: float | None = ...,
        *,
        state: _PreparedStatementState,
        ignore_custom_codec: bool = ...,
        record_class: None,
    ) -> _PreparedStatementState: ...
    @overload
    async def prepare(
        self,
        stmt_name: str,
        query: str,
        timeout: float | None = ...,
        *,
        state: None = ...,
        ignore_custom_codec: bool = ...,
        record_class: type[_OtherRecord],
    ) -> PreparedStatementState[_OtherRecord]: ...
    async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ...
    async def query(self, *args: object, **kwargs: object) -> str: ...
    def resume_writing(self, *args: object, **kwargs: object) -> Any: ...
    def __reduce__(self) -> Any: ...

@final
class Codec:
    __pyx_vtable__: Any
    def __reduce__(self) -> Any: ...

class DataCodecConfig:
    __pyx_vtable__: Any
    def __init__(self) -> None: ...
    def add_python_codec(
        self,
        typeoid: int,
        typename: str,
        typeschema: str,
        typekind: str,
        typeinfos: Iterable[object],
        encoder: Callable[[ConnectionSettings, WriteBuffer, object], object],
        decoder: Callable[..., object],
        format: object,
        xformat: object,
    ) -> Any: ...
    def add_types(self, types: Iterable[object]) -> Any: ...
    def clear_type_cache(self) -> None: ...
    def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ...
    def remove_python_codec(
        self, typeoid: int, typename: str, typeschema: str
    ) -> Any: ...
    def set_builtin_type_codec(
        self,
        typeoid: int,
        typename: str,
        typeschema: str,
        typekind: str,
        alias_to: str,
        format: object = ...,
    ) -> Any: ...
    def __reduce__(self) -> Any: ...

class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ...

# Time-budget tracker used as a context manager; ``budget`` of None
# means no limit.
class Timer:
    def __init__(self, budget: float | None) -> None: ...
    def __enter__(self) -> None: ...
    def __exit__(self, et: object, e: object, tb: object) -> None: ...
    def get_remaining_budget(self) -> float: ...
    def has_budget_greater_than(self, amount: float) -> bool: ...
@final
class SCRAMAuthentication:
    AUTHENTICATION_METHODS: ClassVar[list[str]]
    DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int]
    DIGEST = sha256
    REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]]
    REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]]
    SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]]
    authentication_method: bytes
    authorization_message: bytes | None
    client_channel_binding: bytes
    client_first_message_bare: bytes | None
    client_nonce: bytes | None
    client_proof: bytes | None
    password_salt: bytes | None
    password_iterations: int
    server_first_message: bytes | None
    server_key: hmac.HMAC | None
    server_nonce: bytes | None


================================================
FILE: asyncpg/protocol/protocol.pyx
================================================
# Copyright (C) 2016-present the asyncpg authors and contributors
#
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


# cython: language_level=3

cimport cython
cimport cpython

import asyncio
import builtins
import codecs
import collections.abc
import socket
import time
import weakref

from asyncpg.pgproto.pgproto cimport (
    WriteBuffer,
    ReadBuffer,

    FRBuffer,
    frb_init,
    frb_read,
    frb_read_all,
    frb_slice_from,
    frb_check,
    frb_set_len,
    frb_get_len,
)

from asyncpg.pgproto cimport pgproto
from asyncpg.protocol cimport cpythonx
from asyncpg.protocol cimport recordcapi

from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \
                         int32_t, uint32_t, int64_t, uint64_t, \
                         INT32_MAX, UINT32_MAX

from asyncpg.exceptions import _base as apg_exc_base
from asyncpg import compat
from asyncpg import types as apg_types
from asyncpg import exceptions as apg_exc

from asyncpg.pgproto cimport hton
from asyncpg.protocol.record import Record, RecordDescriptor


include "consts.pxi"
include "pgtypes.pxi"

include "encodings.pyx"
include "settings.pyx"

include "codecs/base.pyx"
include "codecs/textutils.pyx"

# register codecs provided by pgproto
include "codecs/pgproto.pyx"

# nonscalar
include "codecs/array.pyx"
include "codecs/range.pyx"
include "codecs/record.pyx"

include "coreproto.pyx"
include "prepared_stmt.pyx"


# Sentinel recognised by _get_timeout_impl(): explicitly disables the
# timeout instead of falling back to the connection's command_timeout.
NO_TIMEOUT = object()


cdef class BaseProtocol(CoreProtocol):
    """Protocol layer that turns CoreProtocol results into awaitables.

    Each public operation creates a "waiter" future via _new_waiter(),
    issues the corresponding low-level request, and awaits the waiter,
    which is resolved from _dispatch_result() or failed on error,
    timeout or connection loss.
    """

    def __init__(self, addr, connected_fut, con_params,
                 record_class: type, loop):
        """Initialise protocol state.

        ``connected_fut`` becomes the initial waiter, so it is resolved
        by the first dispatched result.
        """
        # type of `con_params` is `_ConnectionParameters`
        CoreProtocol.__init__(self, addr, con_params)

        self.loop = loop
        self.transport = None
        self.waiter = connected_fut
        self.cancel_waiter = None
        self.cancel_sent_waiter = None

        self.settings = ConnectionSettings((addr, con_params.database))
        self.record_class = record_class

        self.statement = None
        self.return_extra = False

        self.last_query = None

        self.closing = False
        self.is_reading = True
        # Set while the transport accepts writes; cleared/set by the
        # pause_writing()/resume_writing() callbacks.
        self.writing_allowed = asyncio.Event()
        self.writing_allowed.set()

        self.timeout_handle = None

        self.queries_count = 0

        self._is_ssl = False

        try:
            self.create_future = loop.create_future
        except AttributeError:
            # The loop does not provide create_future().
            self.create_future = self._create_future_fallback

    def set_connection(self, connection):
        """Remember the owning connection through a weak reference."""
        self.conref = weakref.ref(connection)

    cdef get_connection(self):
        """Return the owning connection, or None if unset or collected."""
        if self.conref is not None:
            return self.conref()
        else:
            return None

    def get_server_pid(self):
        """Return ``backend_pid``."""
        return self.backend_pid

    def get_settings(self):
        """Return the ConnectionSettings instance of this protocol."""
        return self.settings

    def get_record_class(self):
        """Return the record class passed to the constructor."""
        return self.record_class

    cdef inline resume_reading(self):
        """Resume transport reads unless already reading."""
        if not self.is_reading:
            self.is_reading = True
            self.transport.resume_reading()

    cdef inline pause_reading(self):
        """Pause transport reads unless already paused."""
        if self.is_reading:
            self.is_reading = False
            self.transport.pause_reading()

    async def prepare(self, stmt_name, query, timeout,
                      *,
                      PreparedStatementState state=None,
                      ignore_custom_codec=False,
                      record_class):
        """Prepare ``query`` under ``stmt_name`` and await the result.

        A new PreparedStatementState is created unless ``state`` is
        given.  The waiter is resolved with the statement state by
        _on_result__prepare().
        """
        # Let any cancellation in progress finish first.
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)

        waiter = self._new_waiter(timeout)
        try:
            self._prepare_and_describe(stmt_name, query)  # network op
            self.last_query = query
            if state is None:
                state = PreparedStatementState(
                    stmt_name, query, self, record_class, ignore_custom_codec)
            self.statement = state
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    async def bind_execute(
        self,
        state: PreparedStatementState,
        args,
        portal_name: str,
        limit: int,
        return_extra: bool,
        timeout,
    ):
        """Bind ``args`` to ``state`` and execute it, awaiting the result.

        Arguments are encoded before the waiter is created.  When
        ``return_extra`` is true the result is a
        ``(result, status message, execute-completed)`` tuple (see
        _on_result__bind_and_exec).
        """
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)
        args_buf = state._encode_bind_msg(args)

        waiter = self._new_waiter(timeout)
        try:
            # Send a Parse message first if the statement is not marked
            # as prepared.
            if not state.prepared:
                self._send_parse_message(state.name, state.query)

            self._bind_execute(
                portal_name,
                state.name,
                args_buf,
                limit)  # network op

            self.last_query = state.query
            self.statement = state
            self.return_extra = return_extra
            self.queries_count += 1
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    async def bind_execute_many(
        self,
        state: PreparedStatementState,
        args,
        portal_name: str,
        timeout,
        return_rows: bool,
    ):
        """Execute ``state`` once per element of ``args``.

        Argument sets are encoded lazily.  While the low-level layer
        reports that more data remains, this waits (within the timeout
        budget) for the transport to allow writes and then continues.
        """
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)
        timer = Timer(timeout)

        # Make sure the argument sequence is encoded lazily with
        # this generator expression to keep the memory pressure under
        # control.
        data_gen = (state._encode_bind_msg(b, i) for i, b in enumerate(args))
        arg_bufs = iter(data_gen)

        waiter = self._new_waiter(timeout)
        try:
            if not state.prepared:
                self._send_parse_message(state.name, state.query)

            more = self._bind_execute_many(
                portal_name,
                state.name,
                arg_bufs,
                return_rows)  # network op

            self.last_query = state.query
            self.statement = state
            self.return_extra = False
            self.queries_count += 1

            while more:
                with timer:
                    await compat.wait_for(
                        self.writing_allowed.wait(),
                        timeout=timer.get_remaining_budget())
                    # On Windows the above event somehow won't allow context
                    # switch, so forcing one with sleep(0) here
                    await asyncio.sleep(0)
                if not timer.has_budget_greater_than(0):
                    raise asyncio.TimeoutError
                more = self._bind_execute_many_more()  # network op

        except asyncio.TimeoutError as e:
            self._bind_execute_many_fail(e)  # network op

        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    async def bind(self, PreparedStatementState state, args,
                   str portal_name, timeout):
        """Bind ``args`` to ``state`` in ``portal_name`` without executing."""
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)
        args_buf = state._encode_bind_msg(args)

        waiter = self._new_waiter(timeout)
        try:
            self._bind(
                portal_name,
                state.name,
                args_buf)  # network op

            self.last_query = state.query
            self.statement = state
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    async def execute(self, PreparedStatementState state,
                      str portal_name, int limit, return_extra,
                      timeout):
        """Execute the portal ``portal_name`` with the given row ``limit``."""
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)

        waiter = self._new_waiter(timeout)
        try:
            self._execute(
                portal_name,
                limit)  # network op

            self.last_query = state.query
            self.statement = state
            self.return_extra = return_extra
            self.queries_count += 1
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    async def close_portal(self, str portal_name, timeout):
        """Close the portal ``portal_name`` and await the result."""

        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        timeout = self._get_timeout_impl(timeout)

        waiter = self._new_waiter(timeout)
        try:
            self._close(
                portal_name,
                True)  # network op
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    async def query(self, query, timeout):
        """Run ``query`` as a simple query; resolves to the status string."""
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()
        # query() needs to call _get_timeout instead of _get_timeout_impl
        # for consistent validation, as it is called differently from
        # prepare/bind/execute methods.
        timeout = self._get_timeout(timeout)

        waiter = self._new_waiter(timeout)
        try:
            self._simple_query(query)  # network op
            self.last_query = query
            self.queries_count += 1
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    async def copy_out(self, copy_stmt, sink, timeout):
        """Run a COPY OUT statement, feeding each data chunk to ``sink``.

        ``sink`` is awaited with every non-empty buffer.  Returns the
        status message received once the copy is done.
        """
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()

        timeout = self._get_timeout_impl(timeout)
        timer = Timer(timeout)

        # The copy operation is guarded by a single timeout
        # on the top level.
        waiter = self._new_waiter(timer.get_remaining_budget())

        self._copy_out(copy_stmt)

        try:
            while True:
                # _on_result__copy_out() pauses reading before resolving
                # the waiter; re-enable it for the next chunk.
                self.resume_reading()

                with timer:
                    buffer, done, status_msg = await waiter

                # buffer will be empty if CopyDone was received apart from
                # the last CopyData message.
                if buffer:
                    try:
                        with timer:
                            await compat.wait_for(
                                sink(buffer),
                                timeout=timer.get_remaining_budget())
                    except (Exception, asyncio.CancelledError) as ex:
                        # Abort the COPY operation on any error in
                        # output sink.
                        self._request_cancel()
                        # Make asyncio shut up about unretrieved
                        # QueryCanceledError
                        waiter.add_done_callback(lambda f: f.exception())
                        raise

                # done will be True upon receipt of CopyDone.
                if done:
                    break

                waiter = self._new_waiter(timer.get_remaining_budget())

        finally:
            self.resume_reading()

        return status_msg

    async def copy_in(self, copy_stmt, reader, data,
                      records, PreparedStatementState record_stmt, timeout):
        """Run a COPY IN statement and return its status message.

        The data source is chosen in this order: ``records`` encoded in
        binary format using the row codecs of ``record_stmt`` (when
        ``record_stmt`` is not None), the asynchronous iterable
        ``reader``, or the ``data`` buffer passed in directly.
        """
        cdef:
            WriteBuffer wbuf
            ssize_t num_cols
            Codec codec

        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()

        timeout = self._get_timeout_impl(timeout)
        timer = Timer(timeout)

        waiter = self._new_waiter(timer.get_remaining_budget())

        # Initiate COPY IN.
        self._copy_in(copy_stmt)

        try:
            if record_stmt is not None:
                # copy_in_records in binary mode
                wbuf = WriteBuffer.new()
                # Signature
                wbuf.write_bytes(_COPY_SIGNATURE)
                # Flags field
                wbuf.write_int32(0)
                # No header extension
                wbuf.write_int32(0)

                record_stmt._ensure_rows_decoder()
                # NOTE: this local shadows the module-level `codecs`
                # import within this method.
                codecs = record_stmt.rows_codecs
                num_cols = len(codecs)
                settings = self.settings

                # Every column needs a binary-format encoder.
                for codec in codecs:
                    if (not codec.has_encoder() or
                            codec.format != PG_FORMAT_BINARY):
                        raise apg_exc.InternalClientError(
                            'no binary format encoder for '
                            'type {} (OID {})'.format(codec.name, codec.oid))

                if isinstance(records, collections.abc.AsyncIterable):
                    async for row in records:
                        # Tuple header
                        wbuf.write_int16(num_cols)
                        # Tuple data
                        for i in range(num_cols):
                            item = row[i]
                            if item is None:
                                wbuf.write_int32(-1)
                            else:
                                codec = cpython.PyTuple_GET_ITEM(
                                    codecs, i)
                                codec.encode(settings, wbuf, item)

                        # Flush once the buffer reaches the threshold.
                        if wbuf.len() >= _COPY_BUFFER_SIZE:
                            with timer:
                                await self.writing_allowed.wait()
                            self._write_copy_data_msg(wbuf)
                            wbuf = WriteBuffer.new()
                else:
                    for row in records:
                        # Tuple header
                        wbuf.write_int16(num_cols)
                        # Tuple data
                        for i in range(num_cols):
                            item = row[i]
                            if item is None:
                                wbuf.write_int32(-1)
                            else:
                                codec = cpython.PyTuple_GET_ITEM(
                                    codecs, i)
                                codec.encode(settings, wbuf, item)

                        # Flush once the buffer reaches the threshold.
                        if wbuf.len() >= _COPY_BUFFER_SIZE:
                            with timer:
                                await self.writing_allowed.wait()
                            self._write_copy_data_msg(wbuf)
                            wbuf = WriteBuffer.new()

                # End of binary copy.
                wbuf.write_int16(-1)
                self._write_copy_data_msg(wbuf)

            elif reader is not None:
                try:
                    aiter = reader.__aiter__
                except AttributeError:
                    raise TypeError('reader is not an asynchronous iterable')
                else:
                    iterator = aiter()

                try:
                    while True:
                        # We rely on protocol flow control to moderate the
                        # rate of data messages.
                        with timer:
                            await self.writing_allowed.wait()
                        with timer:
                            chunk = await compat.wait_for(
                                iterator.__anext__(),
                                timeout=timer.get_remaining_budget())
                        self._write_copy_data_msg(chunk)
                except builtins.StopAsyncIteration:
                    pass
            else:
                # Buffer passed in directly.
                await self.writing_allowed.wait()
                self._write_copy_data_msg(data)

        except asyncio.TimeoutError:
            self._write_copy_fail_msg('TimeoutError')
            self._on_timeout(self.waiter)
            try:
                await waiter
            except asyncio.TimeoutError:
                # Expected outcome: _on_timeout() fails the waiter with
                # asyncio.TimeoutError.  Any other exception propagates
                # unchanged as well.
                raise
            else:
                raise apg_exc.InternalClientError('TimeoutError was not raised')

        except (Exception, asyncio.CancelledError) as e:
            self._write_copy_fail_msg(str(e))
            self._request_cancel()
            # Make asyncio shut up about unretrieved QueryCanceledError
            waiter.add_done_callback(lambda f: f.exception())
            raise

        self._write_copy_done_msg()

        status_msg = await waiter

        return status_msg

    async def close_statement(self, PreparedStatementState state, timeout):
        """Close the prepared statement ``state`` on the server.

        Raises InternalClientError if the statement still has
        references (``state.refs != 0``).
        """
        if self.cancel_waiter is not None:
            await self.cancel_waiter
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        self._check_state()

        if state.refs != 0:
            raise apg_exc.InternalClientError(
                'cannot close prepared statement; refs == {} != 0'.format(
                    state.refs))

        timeout = self._get_timeout_impl(timeout)
        waiter = self._new_waiter(timeout)
        try:
            self._close(state.name, False)  # network op
            state.closed = True
        except Exception as ex:
            waiter.set_exception(ex)
            self._coreproto_error()
        finally:
            return await waiter

    def is_closed(self):
        """Return True once closing has started."""
        return self.closing

    def is_connected(self):
        """Return True if not closing and the connection status is OK."""
        return not self.closing and self.con_status == CONNECTION_OK

    def abort(self):
        """Close immediately: fail the pending waiter and abort the transport."""
        if self.closing:
            return
        self.closing = True
        self._handle_waiter_on_connection_lost(None)
        self._terminate()
        self.transport.abort()
        self.transport = None

    async def close(self, timeout):
        """Close gracefully: cancel any running query, then terminate.

        Waits for pending cancellations, cancels the in-flight
        operation if there is one, asks the server to terminate and
        waits for the connection to drop before aborting the transport.
        """
        if self.closing:
            return

        self.closing = True

        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None

        if self.cancel_waiter is not None:
            await self.cancel_waiter

        if self.waiter is not None:
            # If there is a query running, cancel it
            self._request_cancel()
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None
            if self.cancel_waiter is not None:
                await self.cancel_waiter

        assert self.waiter is None

        timeout = self._get_timeout_impl(timeout)

        # Ask the server to terminate the connection and wait for it
        # to drop.
        self.waiter = self._new_waiter(timeout)
        self._terminate()
        try:
            await self.waiter
        except ConnectionResetError:
            # There appears to be a difference in behaviour of asyncio
            # in Windows, where, instead of calling protocol.connection_lost()
            # a ConnectionResetError will be thrown into the task.
            pass
        finally:
            self.waiter = None
        self.transport.abort()

    def _request_cancel(self):
        """Start cancelling the current command.

        Creates the two cancellation futures, asks the connection to
        cancel the running command and moves the protocol into the
        PROTOCOL_CANCELLED state.
        """
        self.cancel_waiter = self.create_future()
        self.cancel_sent_waiter = self.create_future()

        con = self.get_connection()
        if con is not None:
            # if 'con' is None it means that the connection object has been
            # garbage collected and that the transport will soon be aborted.
            con._cancel_current_command(self.cancel_sent_waiter)
        else:
            self.loop.call_exception_handler({
                'message': 'asyncpg.Protocol has no reference to its '
                           'Connection object and yet a cancellation '
                           'was requested. Please report this at '
                           'github.com/magicstack/asyncpg.'
            })
            self.abort()

        if self.state == PROTOCOL_PREPARE:
            # we need to send a SYNC to server if we cancel during the PREPARE phase
            # because the PREPARE sequence does not send a SYNC itself.
            # we cannot send this extra SYNC if we are not in PREPARE phase,
            # because then we would issue two SYNCs and we would get two ReadyForQuery
            # replies, which our current state machine implementation cannot handle
            self._write(SYNC_MESSAGE)
        self._set_state(PROTOCOL_CANCELLED)

    def _on_timeout(self, fut):
        """Timeout callback: cancel the command and fail the waiter.

        Ignored when ``fut`` is not the current waiter, is already done,
        a cancellation is in progress, or no timeout is scheduled.
        """
        if self.waiter is not fut or fut.done() or \
                self.cancel_waiter is not None or \
                self.timeout_handle is None:
            return
        self._request_cancel()
        self.waiter.set_exception(asyncio.TimeoutError())

    def _on_waiter_completed(self, fut):
        """Done-callback of a waiter: drop the timeout, propagate cancel."""
        if self.timeout_handle:
            self.timeout_handle.cancel()
            self.timeout_handle = None
        if fut is not self.waiter or self.cancel_waiter is not None:
            return
        # The awaiting side cancelled the future: cancel the command too.
        if fut.cancelled():
            self._request_cancel()

    def _create_future_fallback(self):
        """Create a future for loops that lack ``create_future``."""
        return asyncio.Future(loop=self.loop)

    cdef _handle_waiter_on_connection_lost(self, cause):
        """Fail a pending waiter with ConnectionDoesNotExistError.

        ``cause``, when given, is attached as ``__cause__``.  The waiter
        reference is cleared in all cases.
        """
        if self.waiter is not None and not self.waiter.done():
            exc = apg_exc.ConnectionDoesNotExistError(
                'connection was closed in the middle of '
                'operation')
            if cause is not None:
                exc.__cause__ = cause
            self.waiter.set_exception(exc)
        self.waiter = None

    cdef _set_server_parameter(self, name, val):
        """Record a server parameter in the connection settings."""
        self.settings.add_setting(name, val)

    def _get_timeout(self, timeout):
        """Validate ``timeout`` and resolve it via _get_timeout_impl().

        Raises ValueError for booleans and for values ``float()``
        rejects with ValueError.
        """
        if timeout is not None:
            try:
                # bool converts to float silently, so reject it explicitly.
                if type(timeout) is bool:
                    raise ValueError
                timeout = float(timeout)
            except ValueError:
                raise ValueError(
                    'invalid timeout value: expected non-negative float '
                    '(got {!r})'.format(timeout)) from None

        return self._get_timeout_impl(timeout)

    cdef inline _get_timeout_impl(self, timeout):
        """Resolve the effective timeout.

        None falls back to the connection's ``command_timeout``;
        NO_TIMEOUT yields None (no timeout); anything else is converted
        to float.  Raises asyncio.TimeoutError when the resolved value
        is not positive.
        """
        if timeout is None:
            timeout = self.get_connection()._config.command_timeout
        elif timeout is NO_TIMEOUT:
            timeout = None
        else:
            timeout = float(timeout)

        if timeout is not None and timeout <= 0:
            raise asyncio.TimeoutError()
        return timeout

    cdef _check_state(self):
        """Raise InterfaceError unless a new operation may start."""
        if self.cancel_waiter is not None:
            raise apg_exc.InterfaceError(
                'cannot perform operation: another operation is cancelling')
        if self.closing:
            raise apg_exc.InterfaceError(
                'cannot perform operation: connection is closed')
        if self.waiter is not None or self.timeout_handle is not None:
            raise apg_exc.InterfaceError(
                'cannot perform operation: another operation is in progress')

    def _is_cancelling(self):
        """Return True while either cancellation future is set."""
        return (
            self.cancel_waiter is not None or
            self.cancel_sent_waiter is not None
        )

    async def _wait_for_cancellation(self):
        """Await the pending cancellation futures, if any."""
        if self.cancel_sent_waiter is not None:
            await self.cancel_sent_waiter
            self.cancel_sent_waiter = None
        if self.cancel_waiter is not None:
            await self.cancel_waiter

    cdef _coreproto_error(self):
        """Handle a critical protocol error by aborting the connection.

        Raises InternalClientError if the waiter has not been completed
        yet; the connection is aborted either way.
        """
        try:
            if self.waiter is not None:
                if not self.waiter.done():
                    raise apg_exc.InternalClientError(
                        'waiter is not done while handling critical '
                        'protocol error')
                self.waiter = None
        finally:
            self.abort()

    cdef _new_waiter(self, timeout):
        """Create, register and return the waiter for a new operation.

        Schedules _on_timeout() after ``timeout`` seconds when a
        timeout is given.  Raises InterfaceError if another waiter is
        still registered.
        """
        if self.waiter is not None:
            raise apg_exc.InterfaceError(
                'cannot perform operation: another operation is in progress')
        self.waiter = self.create_future()
        if timeout is not None:
            self.timeout_handle = self.loop.call_later(
                timeout, self._on_timeout, self.waiter)
        self.waiter.add_done_callback(self._on_waiter_completed)
        return self.waiter

    cdef _on_result__connect(self, object waiter):
        """Resolve the connect waiter with True."""
        waiter.set_result(True)

    cdef _on_result__prepare(self, object waiter):
        """Store parameter/row descriptions and resolve with the statement."""
        if PG_DEBUG:
            if self.statement is None:
                raise apg_exc.InternalClientError(
                    '_on_result__prepare: statement is None')

        if self.result_param_desc is not None:
            self.statement._set_args_desc(self.result_param_desc)
        if self.result_row_desc is not None:
            self.statement._set_row_desc(self.result_row_desc)
        waiter.set_result(self.statement)

    cdef _on_result__bind_and_exec(self, object waiter):
        """Resolve with the result, plus status details if requested."""
        if self.return_extra:
            waiter.set_result((
                self.result,
                self.result_status_msg,
                self.result_execute_completed))
        else:
            waiter.set_result(self.result)

    cdef _on_result__bind(self, object waiter):
        """Resolve the waiter with the protocol result."""
        waiter.set_result(self.result)

    cdef _on_result__close_stmt_or_portal(self, object waiter):
        """Resolve the waiter with the protocol result."""
        waiter.set_result(self.result)

    cdef _on_result__simple_query(self, object waiter):
        """Resolve the waiter with the decoded status message."""
        waiter.set_result(self.result_status_msg.decode(self.encoding))

    cdef _on_result__copy_out(self, object waiter):
        """Resolve with ``(data, copy_done, status message or None)``."""
        cdef bint copy_done = self.state == PROTOCOL_COPY_OUT_DONE
        if copy_done:
            status_msg = self.result_status_msg.decode(self.encoding)
        else:
            status_msg = None

        # We need to put some backpressure on Postgres
        # here in case the sink is slow to process the output.
        self.pause_reading()

        waiter.set_result((self.result, copy_done, status_msg))

    cdef _on_result__copy_in(self, object waiter):
        """Resolve the waiter with the decoded status message."""
        status_msg = self.result_status_msg.decode(self.encoding)
        waiter.set_result(status_msg)

    cdef _decode_row(self, const char* buf, ssize_t buf_len):
        """Decode a data row using the current statement."""
        if PG_DEBUG:
            if self.statement is None:
                raise apg_exc.InternalClientError(
                    '_decode_row: statement is None')

        return self.statement._decode_row(buf, buf_len)

    cdef _dispatch_result(self):
        """Complete the current waiter according to the protocol state.

        A failed result becomes an exception on the waiter; otherwise
        the handler matching ``self.state`` resolves it.  Exceptions
        raised by a handler are set on the waiter as well.
        """
        waiter = self.waiter
        self.waiter = None

        if PG_DEBUG:
            if waiter is None:
                raise apg_exc.InternalClientError('_on_result: waiter is None')

        # Nobody is interested in the result any more.
        if waiter.cancelled():
            return

        if waiter.done():
            raise apg_exc.InternalClientError('_on_result: waiter is done')

        if self.result_type == RESULT_FAILED:
            # A dict result is turned into a PostgresError; anything
            # else is used as the exception directly.
            if isinstance(self.result, dict):
                exc = apg_exc_base.PostgresError.new(
                    self.result, query=self.last_query)
            else:
                exc = self.result
            waiter.set_exception(exc)
            return

        try:
            if self.state == PROTOCOL_AUTH:
                self._on_result__connect(waiter)

            elif self.state == PROTOCOL_PREPARE:
                self._on_result__prepare(waiter)

            elif self.state == PROTOCOL_BIND_EXECUTE:
                self._on_result__bind_and_exec(waiter)

            elif self.state == PROTOCOL_BIND_EXECUTE_MANY:
                self._on_result__bind_and_exec(waiter)

            elif self.state == PROTOCOL_EXECUTE:
                self._on_result__bind_and_exec(waiter)

            elif self.state == PROTOCOL_BIND:
                self._on_result__bind(waiter)

            elif self.state == PROTOCOL_CLOSE_STMT_PORTAL:
                self._on_result__close_stmt_or_portal(waiter)

            elif self.state == PROTOCOL_SIMPLE_QUERY:
                self._on_result__simple_query(waiter)

            elif (self.state == PROTOCOL_COPY_OUT_DATA or
                    self.state == PROTOCOL_COPY_OUT_DONE):
                self._on_result__copy_out(waiter)

            elif self.state == PROTOCOL_COPY_IN_DATA:
                self._on_result__copy_in(waiter)

            elif self.state == PROTOCOL_TERMINATING:
                # We are waiting for the connection to drop, so
                # ignore any stray results at this point.
                pass

            else:
                raise apg_exc.InternalClientError(
                    'got result for unknown protocol state {}'.
                    format(self.state))

        except Exception as exc:
            waiter.set_exception(exc)

    cdef _on_result(self):
        """Result callback: clean up timeout/cancel state, then dispatch."""
        if self.timeout_handle is not None:
            self.timeout_handle.cancel()
            self.timeout_handle = None

        if self.cancel_waiter is not None:
            # We have received the result of a cancelled command.
            if not self.cancel_waiter.done():
                # The cancellation future might have been cancelled
                # by the cancellation of the entire task running the query.
                self.cancel_waiter.set_result(None)
            self.cancel_waiter = None
            if self.waiter is not None and self.waiter.done():
                self.waiter = None
            if self.waiter is None:
                return

        try:
            self._dispatch_result()
        finally:
            # Per-operation state never outlives the operation.
            self.statement = None
            self.last_query = None
            self.return_extra = False

    cdef _on_notice(self, parsed):
        """Forward a server notice to the connection, if it still exists."""
        con = self.get_connection()
        if con is not None:
            con._process_log_message(parsed, self.last_query)

    cdef _on_notification(self, pid, channel, payload):
        """Forward a notification to the connection, if it still exists."""
        con = self.get_connection()
        if con is not None:
            con._process_notification(pid, channel, payload)

    cdef _on_connection_lost(self, exc):
        """Settle the pending waiter after the connection has dropped."""
        if self.closing:
            # The connection was lost because
            # Protocol.close() was called
            if self.waiter is not None and not self.waiter.done():
                if exc is None:
                    self.waiter.set_result(None)
                else:
                    self.waiter.set_exception(exc)
            self.waiter = None
        else:
            # The connection was lost because it was
            # terminated or due to another error;
            # Throw an error in any awaiting waiter.
            self.closing = True
            # Cleanup the connection resources, including, possibly,
            # releasing the pool holder.
            con = self.get_connection()
            if con is not None:
                con._cleanup()
            self._handle_waiter_on_connection_lost(exc)

    cdef _write(self, buf):
        """Write ``buf`` to the transport through a memoryview."""
        self.transport.write(memoryview(buf))

    cdef _writelines(self, list buffers):
        """Write a list of buffers to the transport."""
        self.transport.writelines(buffers)

    # asyncio callbacks:

    def data_received(self, data):
        """Feed incoming bytes to the buffer and process server messages."""
        self.buffer.feed_data(data)
        self._read_server_messages()

    def connection_made(self, transport):
        """Store the transport, set TCP_NODELAY and start connecting."""
        self.transport = transport

        sock = transport.get_extra_info('socket')
        # TCP_NODELAY is only set for sockets that are not AF_UNIX.
        if (sock is not None and
              (not hasattr(socket, 'AF_UNIX')
               or sock.family != socket.AF_UNIX)):
            sock.setsockopt(socket.IPPROTO_TCP,
                            socket.TCP_NODELAY, 1)

        try:
            self._connect()
        except Exception as ex:
            transport.abort()
            self.con_status = CONNECTION_BAD
            self._set_state(PROTOCOL_FAILED)
            self._on_error(ex)

    def connection_lost(self, exc):
        """Mark the connection as bad and settle the pending waiter."""
        self.con_status = CONNECTION_BAD
        self._set_state(PROTOCOL_FAILED)
        self._on_connection_lost(exc)

    def pause_writing(self):
        """Flow control: block writers waiting on ``writing_allowed``."""
        self.writing_allowed.clear()

    def resume_writing(self):
        """Flow control: release writers waiting on ``writing_allowed``."""
        self.writing_allowed.set()

    @property
    def is_ssl(self):
        """Value of the ``_is_ssl`` flag."""
        return self._is_ssl

    @is_ssl.setter
    def is_ssl(self, value):
        self._is_ssl = value


class Timer:
    """Context manager that deducts elapsed time from a time budget.

    A budget of None means "unlimited": nothing is measured and the
    budget stays None.
    """

    def __init__(self, budget):
        self._budget = budget
        self._started = 0

    def __enter__(self):
        if self._budget is not None:
            self._started = time.monotonic()

    def __exit__(self, et, e, tb):
        # Subtract the time spent inside the ``with`` block.
        if self._budget is not None:
            self._budget -= time.monotonic() - self._started

    def get_remaining_budget(self):
        """Return the remaining budget (None when unlimited)."""
        return self._budget

    def has_budget_greater_than(self, amount):
        """Return True if more than ``amount`` of the budget remains."""
        if self._budget is None:
            # Unlimited budget.
            return True
        else:
            return self._budget > amount


class Protocol(BaseProtocol, asyncio.Protocol):
    pass


def _create_record(object mapping, tuple elems):
    # Exposed only for testing purposes.
cdef: object rec int32_t i if mapping is None: desc = RecordDescriptor({}, ()) else: desc = RecordDescriptor( mapping, tuple(mapping) if mapping else ()) rec = desc.make_record(Record, len(elems)) for i in range(len(elems)): elem = elems[i] cpython.Py_INCREF(elem) recordcapi.ApgRecord_SET_ITEM(rec, i, elem) return rec ================================================ FILE: asyncpg/protocol/record/pythoncapi_compat.h ================================================ // Header file providing new C API functions to old Python versions. // // File distributed under the Zero Clause BSD (0BSD) license. // Copyright Contributors to the pythoncapi_compat project. // // Homepage: // https://github.com/python/pythoncapi_compat // // Latest version: // https://raw.githubusercontent.com/python/pythoncapi-compat/main/pythoncapi_compat.h // // SPDX-License-Identifier: 0BSD #ifndef PYTHONCAPI_COMPAT #define PYTHONCAPI_COMPAT #ifdef __cplusplus extern "C" { #endif #include #include // offsetof() // Python 3.11.0b4 added PyFrame_Back() to Python.h #if PY_VERSION_HEX < 0x030b00B4 && !defined(PYPY_VERSION) # include "frameobject.h" // PyFrameObject, PyFrame_GetBack() #endif #if PY_VERSION_HEX < 0x030C00A3 # include // T_SHORT, READONLY #endif #ifndef _Py_CAST # define _Py_CAST(type, expr) ((type)(expr)) #endif // Static inline functions should use _Py_NULL rather than using directly NULL // to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer, // _Py_NULL is defined as nullptr. #ifndef _Py_NULL # if (defined (__STDC_VERSION__) && __STDC_VERSION__ > 201710L) \ || (defined(__cplusplus) && __cplusplus >= 201103) # define _Py_NULL nullptr # else # define _Py_NULL NULL # endif #endif // Cast argument to PyObject* type. 
#ifndef _PyObject_CAST
#  define _PyObject_CAST(op) _Py_CAST(PyObject*, op)
#endif

// Compile-time assertion: the array size becomes negative (a compile error)
// when `cond` is false.
#ifndef Py_BUILD_ASSERT
#  define Py_BUILD_ASSERT(cond) \
        do { \
            (void)sizeof(char [1 - 2 * !(cond)]); \
        } while(0)
#endif

// bpo-42262 added Py_NewRef() to Python 3.10.0a3
#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_NewRef)
// Take a new strong reference to `obj` and return it. `obj` must not be NULL
// (Py_INCREF() is applied unconditionally).
static inline PyObject* _Py_NewRef(PyObject *obj)
{
    Py_INCREF(obj);
    return obj;
}
#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
#endif

// bpo-42262 added Py_XNewRef() to Python 3.10.0a3
#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_XNewRef)
// Same as _Py_NewRef(), but uses Py_XINCREF() so a NULL `obj` is passed
// through unchanged.
static inline PyObject* _Py_XNewRef(PyObject *obj)
{
    Py_XINCREF(obj);
    return obj;
}
#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
#endif

// bpo-39573 added Py_SET_REFCNT() to Python 3.9.0a4
#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_REFCNT)
// Overwrite the reference count of `ob` by writing ob_refcnt directly.
static inline void _Py_SET_REFCNT(PyObject *ob, Py_ssize_t refcnt)
{
    ob->ob_refcnt = refcnt;
}
#define Py_SET_REFCNT(ob, refcnt) _Py_SET_REFCNT(_PyObject_CAST(ob), refcnt)
#endif


// Py_SETREF() and Py_XSETREF() were added to Python 3.5.2.
// It is excluded from the limited C API.
//
// Both macros store `src` into `dst` first and only then release the value
// `dst` held before, so `dst` never refers to an already-released object.
// Py_SETREF() releases with Py_DECREF() (old value must not be NULL);
// Py_XSETREF() releases with Py_XDECREF() (old value may be NULL).
#if (PY_VERSION_HEX < 0x03050200 && !defined(Py_SETREF)) && !defined(Py_LIMITED_API)
#define Py_SETREF(dst, src)                                     \
    do {                                                        \
        PyObject **_tmp_dst_ptr = _Py_CAST(PyObject**, &(dst)); \
        PyObject *_tmp_dst = (*_tmp_dst_ptr);                   \
        *_tmp_dst_ptr = _PyObject_CAST(src);                    \
        Py_DECREF(_tmp_dst);                                    \
    } while (0)

#define Py_XSETREF(dst, src)                                    \
    do {                                                        \
        PyObject **_tmp_dst_ptr = _Py_CAST(PyObject**, &(dst)); \
        PyObject *_tmp_dst = (*_tmp_dst_ptr);                   \
        *_tmp_dst_ptr = _PyObject_CAST(src);                    \
        Py_XDECREF(_tmp_dst);                                   \
    } while (0)
#endif


// bpo-43753 added Py_Is(), Py_IsNone(), Py_IsTrue() and Py_IsFalse()
// to Python 3.10.0b1.
// Identity comparisons: Py_Is() is a plain pointer comparison and the other
// three macros compare against the Py_None / Py_True / Py_False singletons.
#if PY_VERSION_HEX < 0x030A00B1 && !defined(Py_Is)
#  define Py_Is(x, y) ((x) == (y))
#endif
#if PY_VERSION_HEX < 0x030A00B1 && !defined(Py_IsNone)
#  define Py_IsNone(x) Py_Is(x, Py_None)
#endif
#if (PY_VERSION_HEX < 0x030A00B1 || defined(PYPY_VERSION)) && !defined(Py_IsTrue)
#  define Py_IsTrue(x) Py_Is(x, Py_True)
#endif
#if (PY_VERSION_HEX < 0x030A00B1 || defined(PYPY_VERSION)) && !defined(Py_IsFalse)
#  define Py_IsFalse(x) Py_Is(x, Py_False)
#endif


// bpo-39573 added Py_SET_TYPE() to Python 3.9.0a4
#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE)
// Replace the type pointer of `ob` by writing ob_type directly.
static inline void _Py_SET_TYPE(PyObject *ob, PyTypeObject *type)
{
    ob->ob_type = type;
}
#define Py_SET_TYPE(ob, type) _Py_SET_TYPE(_PyObject_CAST(ob), type)
#endif


// bpo-39573 added Py_SET_SIZE() to Python 3.9.0a4
#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_SIZE)
// Replace the ob_size field of a variable-size object.
static inline void _Py_SET_SIZE(PyVarObject *ob, Py_ssize_t size)
{
    ob->ob_size = size;
}
#define Py_SET_SIZE(ob, size) _Py_SET_SIZE((PyVarObject*)(ob), size)
#endif


// bpo-40421 added PyFrame_GetCode() to Python 3.9.0b1
#if PY_VERSION_HEX < 0x030900B1 || defined(PYPY_VERSION)
// Return a new strong reference to the frame's code object (frame->f_code).
static inline PyCodeObject* PyFrame_GetCode(PyFrameObject *frame)
{
    assert(frame != _Py_NULL);
    assert(frame->f_code != _Py_NULL);
    return _Py_CAST(PyCodeObject*, Py_NewRef(frame->f_code));
}
#endif

// Borrowed-reference variant: the reference obtained from PyFrame_GetCode()
// is released before the pointer is returned.
static inline PyCodeObject* _PyFrame_GetCodeBorrow(PyFrameObject *frame)
{
    PyCodeObject *code = PyFrame_GetCode(frame);
    Py_DECREF(code);
    return code;
}


// bpo-40421 added PyFrame_GetBack() to Python 3.9.0b1
#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
// Return a new strong reference to frame->f_back, or NULL when it is NULL
// (Py_XNewRef() passes NULL through).
static inline PyFrameObject* PyFrame_GetBack(PyFrameObject *frame)
{
    assert(frame != _Py_NULL);
    return _Py_CAST(PyFrameObject*, Py_XNewRef(frame->f_back));
}
#endif

#if !defined(PYPY_VERSION)
// Borrowed-reference variant of PyFrame_GetBack(); the result may be NULL.
static inline PyFrameObject* _PyFrame_GetBackBorrow(PyFrameObject *frame)
{
    PyFrameObject *back = PyFrame_GetBack(frame);
    Py_XDECREF(back);
    return back;
}
#endif


// bpo-40421 added PyFrame_GetLocals() to Python 3.11.0a7
#if PY_VERSION_HEX < 0x030B00A7 && !defined(PYPY_VERSION)
// Refresh frame->f_locals from the fast locals and return a new strong
// reference to it. Returns NULL if PyFrame_FastToLocalsWithError() fails.
static inline PyObject* PyFrame_GetLocals(PyFrameObject *frame)
{
#if PY_VERSION_HEX >= 0x030400B1
    if (PyFrame_FastToLocalsWithError(frame) < 0) {
        return NULL;
    }
#else
    PyFrame_FastToLocals(frame);
#endif
    return Py_NewRef(frame->f_locals);
}
#endif


// bpo-40421 added PyFrame_GetGlobals() to Python 3.11.0a7
#if PY_VERSION_HEX < 0x030B00A7 && !defined(PYPY_VERSION)
// Return a new strong reference to frame->f_globals.
static inline PyObject* PyFrame_GetGlobals(PyFrameObject *frame)
{
    return Py_NewRef(frame->f_globals);
}
#endif


// bpo-40421 added PyFrame_GetBuiltins() to Python 3.11.0a7
#if PY_VERSION_HEX < 0x030B00A7 && !defined(PYPY_VERSION)
// Return a new strong reference to frame->f_builtins.
static inline PyObject* PyFrame_GetBuiltins(PyFrameObject *frame)
{
    return Py_NewRef(frame->f_builtins);
}
#endif


// bpo-40421 added PyFrame_GetLasti() to Python 3.11.0b1
#if PY_VERSION_HEX < 0x030B00B1 && !defined(PYPY_VERSION)
// Return the frame's last-instruction position as a byte offset, or -1 when
// f_lasti is negative (3.10.0a7+ branch).
static inline int PyFrame_GetLasti(PyFrameObject *frame)
{
#if PY_VERSION_HEX >= 0x030A00A7
    // bpo-27129: Since Python 3.10.0a7, f_lasti is an instruction offset,
    // not a bytes offset anymore. Python uses 16-bit "wordcode" (2 bytes)
    // instructions.
    if (frame->f_lasti < 0) {
        return -1;
    }
    return frame->f_lasti * 2;
#else
    return frame->f_lasti;
#endif
}
#endif


// gh-91248 added PyFrame_GetVar() to Python 3.12.0a2
#if PY_VERSION_HEX < 0x030C00A2 && !defined(PYPY_VERSION)
// Look `name` up in the frame's locals and return a new strong reference to
// the value. Returns NULL with an exception set on failure; a missing name
// raises NameError.
static inline PyObject* PyFrame_GetVar(PyFrameObject *frame, PyObject *name)
{
    PyObject *locals, *value;

    locals = PyFrame_GetLocals(frame);
    if (locals == NULL) {
        return NULL;
    }
#if PY_VERSION_HEX >= 0x03000000
    value = PyDict_GetItemWithError(locals, name);
#else
    value = _PyDict_GetItemWithError(locals, name);
#endif
    Py_DECREF(locals);

    if (value == NULL) {
        // NULL without an exception means "key not found".
        if (PyErr_Occurred()) {
            return NULL;
        }
#if PY_VERSION_HEX >= 0x03000000
        PyErr_Format(PyExc_NameError, "variable %R does not exist", name);
#else
        PyErr_SetString(PyExc_NameError, "variable does not exist");
#endif
        return NULL;
    }
    return Py_NewRef(value);
}
#endif


// gh-91248 added PyFrame_GetVarString() to Python 3.12.0a2
#if PY_VERSION_HEX < 0x030C00A2 && !defined(PYPY_VERSION)
// Same as PyFrame_GetVar(), with the variable name given as a C string.
static inline PyObject*
PyFrame_GetVarString(PyFrameObject *frame, const char *name)
{
    PyObject *name_obj, *value;
#if PY_VERSION_HEX >= 0x03000000
    name_obj = PyUnicode_FromString(name);
#else
    name_obj = PyString_FromString(name);
#endif
    if (name_obj == NULL) {
        return NULL;
    }
    value = PyFrame_GetVar(frame, name_obj);
    Py_DECREF(name_obj);
    return value;
}
#endif


// bpo-39947 added PyThreadState_GetInterpreter() to Python 3.9.0a5
#if PY_VERSION_HEX < 0x030900A5 || (defined(PYPY_VERSION) && PY_VERSION_HEX < 0x030B0000)
// Return tstate->interp (no reference counting is involved).
static inline PyInterpreterState *
PyThreadState_GetInterpreter(PyThreadState *tstate)
{
    assert(tstate != _Py_NULL);
    return tstate->interp;
}
#endif


// bpo-40429 added PyThreadState_GetFrame() to Python 3.9.0b1
#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
// Return a new strong reference to tstate->frame, or NULL when it is NULL.
static inline PyFrameObject* PyThreadState_GetFrame(PyThreadState *tstate)
{
    assert(tstate != _Py_NULL);
    return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
}
#endif

#if !defined(PYPY_VERSION)
// Borrowed-reference variant of PyThreadState_GetFrame(); may return NULL.
static inline PyFrameObject*
_PyThreadState_GetFrameBorrow(PyThreadState *tstate)
{
    PyFrameObject *frame = PyThreadState_GetFrame(tstate);
    Py_XDECREF(frame);
    return frame;
}
#endif


// bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a5
#if PY_VERSION_HEX < 0x030900A5 || defined(PYPY_VERSION)
// Return the interpreter of the current thread state. Calls Py_FatalError()
// when there is no current thread state or it has no interpreter.
static inline PyInterpreterState* PyInterpreterState_Get(void)
{
    PyThreadState *tstate;
    PyInterpreterState *interp;

    tstate = PyThreadState_GET();
    if (tstate == _Py_NULL) {
        Py_FatalError("GIL released (tstate is NULL)");
    }
    interp = tstate->interp;
    if (interp == _Py_NULL) {
        Py_FatalError("no current interpreter");
    }
    return interp;
}
#endif


// bpo-39947 added PyThreadState_GetID() to Python 3.9.0a6
#if 0x030700A1 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION)
// Return tstate->id.
static inline uint64_t PyThreadState_GetID(PyThreadState *tstate)
{
    assert(tstate != _Py_NULL);
    return tstate->id;
}
#endif

// bpo-43760 added PyThreadState_EnterTracing() to Python 3.11.0a2
#if PY_VERSION_HEX < 0x030B00A2 && !defined(PYPY_VERSION)
// Increment tstate->tracing and clear the use_tracing flag (stored on
// tstate->cframe since 3.10.0a1, on tstate itself before that).
static inline void PyThreadState_EnterTracing(PyThreadState *tstate)
{
    tstate->tracing++;
#if PY_VERSION_HEX >= 0x030A00A1
    tstate->cframe->use_tracing = 0;
#else
    tstate->use_tracing = 0;
#endif
}
#endif

// bpo-43760 added PyThreadState_LeaveTracing() to Python 3.11.0a2
#if PY_VERSION_HEX < 0x030B00A2 && !defined(PYPY_VERSION)
// Decrement tstate->tracing and set use_tracing again iff a trace or
// profile function is installed on the thread state.
static inline void PyThreadState_LeaveTracing(PyThreadState *tstate)
{
    int use_tracing = (tstate->c_tracefunc != _Py_NULL
                       || tstate->c_profilefunc != _Py_NULL);
    tstate->tracing--;
#if PY_VERSION_HEX >= 0x030A00A1
    tstate->cframe->use_tracing = use_tracing;
#else
    tstate->use_tracing = use_tracing;
#endif
}
#endif


// bpo-37194 added PyObject_CallNoArgs() to Python 3.9.0a1
// PyObject_CallNoArgs() added to PyPy 3.9.16-v7.3.11
#if !defined(PyObject_CallNoArgs) && PY_VERSION_HEX < 0x030900A1
// Call `func` with no arguments.
static inline PyObject* PyObject_CallNoArgs(PyObject *func)
{
    return PyObject_CallFunctionObjArgs(func, NULL);
}
#endif


// bpo-39245 made PyObject_CallOneArg() public (previously called
// _PyObject_CallOneArg) in Python 3.9.0a4
// PyObject_CallOneArg() added to PyPy 3.9.16-v7.3.11
#if !defined(PyObject_CallOneArg) && PY_VERSION_HEX < 0x030900A4
// Call `func` with the single positional argument `arg`.
static inline PyObject* PyObject_CallOneArg(PyObject *func, PyObject *arg)
{
    return PyObject_CallFunctionObjArgs(func, arg, NULL);
}
#endif


// bpo-1635741 added PyModule_AddObjectRef() to Python 3.10.0a3
#if PY_VERSION_HEX < 0x030A00A3
// Add `value` to `module` under `name` without consuming the caller's
// reference. Returns the PyModule_AddObject() result; returns -1 with
// SystemError when `value` is NULL and no exception is set.
static inline int
PyModule_AddObjectRef(PyObject *module, const char *name, PyObject *value)
{
    int res;

    if (!value && !PyErr_Occurred()) {
        // PyModule_AddObject() raises TypeError in this case
        PyErr_SetString(PyExc_SystemError,
                        "PyModule_AddObjectRef() must be called "
                        "with an exception raised if value is NULL");
        return -1;
    }

    // Take the extra reference handed to PyModule_AddObject(); give it back
    // if that call fails.
    Py_XINCREF(value);
    res = PyModule_AddObject(module, name, value);
    if (res < 0) {
        Py_XDECREF(value);
    }
    return res;
}
#endif


// bpo-40024 added PyModule_AddType() to Python 3.9.0a5
#if PY_VERSION_HEX < 0x030900A5
// Ready `type` and add it to `module` under the last dotted component of
// type->tp_name. Returns -1 if PyType_Ready() fails.
static inline int PyModule_AddType(PyObject *module, PyTypeObject *type)
{
    const char *name, *dot;

    if (PyType_Ready(type) < 0) {
        return -1;
    }

    // inline _PyType_Name()
    name = type->tp_name;
    assert(name != _Py_NULL);
    dot = strrchr(name, '.');
    if (dot != _Py_NULL) {
        name = dot + 1;
    }

    return PyModule_AddObjectRef(module, name, _PyObject_CAST(type));
}
#endif


// bpo-40241 added PyObject_GC_IsTracked() to Python 3.9.0a6.
// bpo-4688 added _PyObject_GC_IS_TRACKED() to Python 2.7.0a2.
#if PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION)
// Non-zero only when `obj` is a GC-enabled object that is currently tracked.
static inline int PyObject_GC_IsTracked(PyObject* obj)
{
    return (PyObject_IS_GC(obj) && _PyObject_GC_IS_TRACKED(obj));
}
#endif

// bpo-40241 added PyObject_GC_IsFinalized() to Python 3.9.0a6.
// bpo-18112 added _PyGCHead_FINALIZED() to Python 3.4.0 final.
#if PY_VERSION_HEX < 0x030900A6 && PY_VERSION_HEX >= 0x030400F0 && !defined(PYPY_VERSION) static inline int PyObject_GC_IsFinalized(PyObject *obj) { PyGC_Head *gc = _Py_CAST(PyGC_Head*, obj) - 1; return (PyObject_IS_GC(obj) && _PyGCHead_FINALIZED(gc)); } #endif // bpo-39573 added Py_IS_TYPE() to Python 3.9.0a4 #if PY_VERSION_HEX < 0x030900A4 && !defined(Py_IS_TYPE) static inline int _Py_IS_TYPE(PyObject *ob, PyTypeObject *type) { return Py_TYPE(ob) == type; } #define Py_IS_TYPE(ob, type) _Py_IS_TYPE(_PyObject_CAST(ob), type) #endif // bpo-46906 added PyFloat_Pack2() and PyFloat_Unpack2() to Python 3.11a7. // bpo-11734 added _PyFloat_Pack2() and _PyFloat_Unpack2() to Python 3.6.0b1. // Python 3.11a2 moved _PyFloat_Pack2() and _PyFloat_Unpack2() to the internal // C API: Python 3.11a2-3.11a6 versions are not supported. #if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION) static inline int PyFloat_Pack2(double x, char *p, int le) { return _PyFloat_Pack2(x, (unsigned char*)p, le); } static inline double PyFloat_Unpack2(const char *p, int le) { return _PyFloat_Unpack2((const unsigned char *)p, le); } #endif // bpo-46906 added PyFloat_Pack4(), PyFloat_Pack8(), PyFloat_Unpack4() and // PyFloat_Unpack8() to Python 3.11a7. // Python 3.11a2 moved _PyFloat_Pack4(), _PyFloat_Pack8(), _PyFloat_Unpack4() // and _PyFloat_Unpack8() to the internal C API: Python 3.11a2-3.11a6 versions // are not supported. 
#if PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION) static inline int PyFloat_Pack4(double x, char *p, int le) { return _PyFloat_Pack4(x, (unsigned char*)p, le); } static inline int PyFloat_Pack8(double x, char *p, int le) { return _PyFloat_Pack8(x, (unsigned char*)p, le); } static inline double PyFloat_Unpack4(const char *p, int le) { return _PyFloat_Unpack4((const unsigned char *)p, le); } static inline double PyFloat_Unpack8(const char *p, int le) { return _PyFloat_Unpack8((const unsigned char *)p, le); } #endif // gh-92154 added PyCode_GetCode() to Python 3.11.0b1 #if PY_VERSION_HEX < 0x030B00B1 && !defined(PYPY_VERSION) static inline PyObject* PyCode_GetCode(PyCodeObject *code) { return Py_NewRef(code->co_code); } #endif // gh-95008 added PyCode_GetVarnames() to Python 3.11.0rc1 #if PY_VERSION_HEX < 0x030B00C1 && !defined(PYPY_VERSION) static inline PyObject* PyCode_GetVarnames(PyCodeObject *code) { return Py_NewRef(code->co_varnames); } #endif // gh-95008 added PyCode_GetFreevars() to Python 3.11.0rc1 #if PY_VERSION_HEX < 0x030B00C1 && !defined(PYPY_VERSION) static inline PyObject* PyCode_GetFreevars(PyCodeObject *code) { return Py_NewRef(code->co_freevars); } #endif // gh-95008 added PyCode_GetCellvars() to Python 3.11.0rc1 #if PY_VERSION_HEX < 0x030B00C1 && !defined(PYPY_VERSION) static inline PyObject* PyCode_GetCellvars(PyCodeObject *code) { return Py_NewRef(code->co_cellvars); } #endif // Py_UNUSED() was added to Python 3.4.0b2. 
#if PY_VERSION_HEX < 0x030400B2 && !defined(Py_UNUSED) # if defined(__GNUC__) || defined(__clang__) # define Py_UNUSED(name) _unused_ ## name __attribute__((unused)) # else # define Py_UNUSED(name) _unused_ ## name # endif #endif // gh-105922 added PyImport_AddModuleRef() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A0 static inline PyObject* PyImport_AddModuleRef(const char *name) { return Py_XNewRef(PyImport_AddModule(name)); } #endif // gh-105927 added PyWeakref_GetRef() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D0000 static inline int PyWeakref_GetRef(PyObject *ref, PyObject **pobj) { PyObject *obj; if (ref != NULL && !PyWeakref_Check(ref)) { *pobj = NULL; PyErr_SetString(PyExc_TypeError, "expected a weakref"); return -1; } obj = PyWeakref_GetObject(ref); if (obj == NULL) { // SystemError if ref is NULL *pobj = NULL; return -1; } if (obj == Py_None) { *pobj = NULL; return 0; } *pobj = Py_NewRef(obj); return 1; } #endif // bpo-36974 added PY_VECTORCALL_ARGUMENTS_OFFSET to Python 3.8b1 #ifndef PY_VECTORCALL_ARGUMENTS_OFFSET # define PY_VECTORCALL_ARGUMENTS_OFFSET (_Py_CAST(size_t, 1) << (8 * sizeof(size_t) - 1)) #endif // bpo-36974 added PyVectorcall_NARGS() to Python 3.8b1 #if PY_VERSION_HEX < 0x030800B1 static inline Py_ssize_t PyVectorcall_NARGS(size_t n) { return n & ~PY_VECTORCALL_ARGUMENTS_OFFSET; } #endif // gh-105922 added PyObject_Vectorcall() to Python 3.9.0a4 #if PY_VERSION_HEX < 0x030900A4 static inline PyObject* PyObject_Vectorcall(PyObject *callable, PyObject *const *args, size_t nargsf, PyObject *kwnames) { #if PY_VERSION_HEX >= 0x030800B1 && !defined(PYPY_VERSION) // bpo-36974 added _PyObject_Vectorcall() to Python 3.8.0b1 return _PyObject_Vectorcall(callable, args, nargsf, kwnames); #else PyObject *posargs = NULL, *kwargs = NULL; PyObject *res; Py_ssize_t nposargs, nkwargs, i; if (nargsf != 0 && args == NULL) { PyErr_BadInternalCall(); goto error; } if (kwnames != NULL && !PyTuple_Check(kwnames)) { PyErr_BadInternalCall(); goto error; } 
nposargs = (Py_ssize_t)PyVectorcall_NARGS(nargsf); if (kwnames) { nkwargs = PyTuple_GET_SIZE(kwnames); } else { nkwargs = 0; } posargs = PyTuple_New(nposargs); if (posargs == NULL) { goto error; } if (nposargs) { for (i=0; i < nposargs; i++) { PyTuple_SET_ITEM(posargs, i, Py_NewRef(*args)); args++; } } if (nkwargs) { kwargs = PyDict_New(); if (kwargs == NULL) { goto error; } for (i = 0; i < nkwargs; i++) { PyObject *key = PyTuple_GET_ITEM(kwnames, i); PyObject *value = *args; args++; if (PyDict_SetItem(kwargs, key, value) < 0) { goto error; } } } else { kwargs = NULL; } res = PyObject_Call(callable, posargs, kwargs); Py_DECREF(posargs); Py_XDECREF(kwargs); return res; error: Py_DECREF(posargs); Py_XDECREF(kwargs); return NULL; #endif } #endif // gh-106521 added PyObject_GetOptionalAttr() and // PyObject_GetOptionalAttrString() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 static inline int PyObject_GetOptionalAttr(PyObject *obj, PyObject *attr_name, PyObject **result) { // bpo-32571 added _PyObject_LookupAttr() to Python 3.7.0b1 #if PY_VERSION_HEX >= 0x030700B1 && !defined(PYPY_VERSION) return _PyObject_LookupAttr(obj, attr_name, result); #else *result = PyObject_GetAttr(obj, attr_name); if (*result != NULL) { return 1; } if (!PyErr_Occurred()) { return 0; } if (PyErr_ExceptionMatches(PyExc_AttributeError)) { PyErr_Clear(); return 0; } return -1; #endif } static inline int PyObject_GetOptionalAttrString(PyObject *obj, const char *attr_name, PyObject **result) { PyObject *name_obj; int rc; #if PY_VERSION_HEX >= 0x03000000 name_obj = PyUnicode_FromString(attr_name); #else name_obj = PyString_FromString(attr_name); #endif if (name_obj == NULL) { *result = NULL; return -1; } rc = PyObject_GetOptionalAttr(obj, name_obj, result); Py_DECREF(name_obj); return rc; } #endif // gh-106307 added PyObject_GetOptionalAttr() and // PyMapping_GetOptionalItemString() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 static inline int PyMapping_GetOptionalItem(PyObject *obj, 
PyObject *key, PyObject **result) { *result = PyObject_GetItem(obj, key); if (*result) { return 1; } if (!PyErr_ExceptionMatches(PyExc_KeyError)) { return -1; } PyErr_Clear(); return 0; } static inline int PyMapping_GetOptionalItemString(PyObject *obj, const char *key, PyObject **result) { PyObject *key_obj; int rc; #if PY_VERSION_HEX >= 0x03000000 key_obj = PyUnicode_FromString(key); #else key_obj = PyString_FromString(key); #endif if (key_obj == NULL) { *result = NULL; return -1; } rc = PyMapping_GetOptionalItem(obj, key_obj, result); Py_DECREF(key_obj); return rc; } #endif // gh-108511 added PyMapping_HasKeyWithError() and // PyMapping_HasKeyStringWithError() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 static inline int PyMapping_HasKeyWithError(PyObject *obj, PyObject *key) { PyObject *res; int rc = PyMapping_GetOptionalItem(obj, key, &res); Py_XDECREF(res); return rc; } static inline int PyMapping_HasKeyStringWithError(PyObject *obj, const char *key) { PyObject *res; int rc = PyMapping_GetOptionalItemString(obj, key, &res); Py_XDECREF(res); return rc; } #endif // gh-108511 added PyObject_HasAttrWithError() and // PyObject_HasAttrStringWithError() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 static inline int PyObject_HasAttrWithError(PyObject *obj, PyObject *attr) { PyObject *res; int rc = PyObject_GetOptionalAttr(obj, attr, &res); Py_XDECREF(res); return rc; } static inline int PyObject_HasAttrStringWithError(PyObject *obj, const char *attr) { PyObject *res; int rc = PyObject_GetOptionalAttrString(obj, attr, &res); Py_XDECREF(res); return rc; } #endif // gh-106004 added PyDict_GetItemRef() and PyDict_GetItemStringRef() // to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 static inline int PyDict_GetItemRef(PyObject *mp, PyObject *key, PyObject **result) { #if PY_VERSION_HEX >= 0x03000000 PyObject *item = PyDict_GetItemWithError(mp, key); #else PyObject *item = _PyDict_GetItemWithError(mp, key); #endif if (item != NULL) { *result = 
Py_NewRef(item); return 1; // found } if (!PyErr_Occurred()) { *result = NULL; return 0; // not found } *result = NULL; return -1; } static inline int PyDict_GetItemStringRef(PyObject *mp, const char *key, PyObject **result) { int res; #if PY_VERSION_HEX >= 0x03000000 PyObject *key_obj = PyUnicode_FromString(key); #else PyObject *key_obj = PyString_FromString(key); #endif if (key_obj == NULL) { *result = NULL; return -1; } res = PyDict_GetItemRef(mp, key_obj, result); Py_DECREF(key_obj); return res; } #endif // gh-106307 added PyModule_Add() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 static inline int PyModule_Add(PyObject *mod, const char *name, PyObject *value) { int res = PyModule_AddObjectRef(mod, name, value); Py_XDECREF(value); return res; } #endif // gh-108014 added Py_IsFinalizing() to Python 3.13.0a1 // bpo-1856 added _Py_Finalizing to Python 3.2.1b1. // _Py_IsFinalizing() was added to PyPy 7.3.0. #if (0x030201B1 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030D00A1) \ && (!defined(PYPY_VERSION_NUM) || PYPY_VERSION_NUM >= 0x7030000) static inline int Py_IsFinalizing(void) { #if PY_VERSION_HEX >= 0x030700A1 // _Py_IsFinalizing() was added to Python 3.7.0a1. 
return _Py_IsFinalizing(); #else return (_Py_Finalizing != NULL); #endif } #endif // gh-108323 added PyDict_ContainsString() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 static inline int PyDict_ContainsString(PyObject *op, const char *key) { PyObject *key_obj = PyUnicode_FromString(key); if (key_obj == NULL) { return -1; } int res = PyDict_Contains(op, key_obj); Py_DECREF(key_obj); return res; } #endif // gh-108445 added PyLong_AsInt() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 static inline int PyLong_AsInt(PyObject *obj) { #ifdef PYPY_VERSION long value = PyLong_AsLong(obj); if (value == -1 && PyErr_Occurred()) { return -1; } if (value < (long)INT_MIN || (long)INT_MAX < value) { PyErr_SetString(PyExc_OverflowError, "Python int too large to convert to C int"); return -1; } return (int)value; #else return _PyLong_AsInt(obj); #endif } #endif // gh-107073 added PyObject_VisitManagedDict() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 static inline int PyObject_VisitManagedDict(PyObject *obj, visitproc visit, void *arg) { PyObject **dict = _PyObject_GetDictPtr(obj); if (dict == NULL || *dict == NULL) { return -1; } Py_VISIT(*dict); return 0; } static inline void PyObject_ClearManagedDict(PyObject *obj) { PyObject **dict = _PyObject_GetDictPtr(obj); if (dict == NULL || *dict == NULL) { return; } Py_CLEAR(*dict); } #endif // gh-108867 added PyThreadState_GetUnchecked() to Python 3.13.0a1 // Python 3.5.2 added _PyThreadState_UncheckedGet(). 
#if PY_VERSION_HEX >= 0x03050200 && PY_VERSION_HEX < 0x030D00A1
// Public-name wrapper around _PyThreadState_UncheckedGet().
static inline PyThreadState*
PyThreadState_GetUnchecked(void)
{
    return _PyThreadState_UncheckedGet();
}
#endif

// gh-110289 added PyUnicode_EqualToUTF8() and PyUnicode_EqualToUTF8AndSize()
// to Python 3.13.0a1
#if PY_VERSION_HEX < 0x030D00A1
// Compare `unicode` with the UTF-8 buffer `str` of length `str_len`.
// Returns 1 when equal and 0 otherwise. The function never reports an
// error: any exception pending on entry is saved and restored, and an
// internal failure is reported as "not equal".
static inline int
PyUnicode_EqualToUTF8AndSize(PyObject *unicode, const char *str, Py_ssize_t str_len)
{
    Py_ssize_t len;
    const void *utf8;
    PyObject *exc_type, *exc_value, *exc_tb;
    int res;

    // API cannot report errors so save/restore the exception
    PyErr_Fetch(&exc_type, &exc_value, &exc_tb);

    // Python 3.3.0a1 added PyUnicode_AsUTF8AndSize()
#if PY_VERSION_HEX >= 0x030300A1
    if (PyUnicode_IS_ASCII(unicode)) {
        // ASCII strings are compared directly against their own data buffer.
        utf8 = PyUnicode_DATA(unicode);
        len = PyUnicode_GET_LENGTH(unicode);
    }
    else {
        utf8 = PyUnicode_AsUTF8AndSize(unicode, &len);
        if (utf8 == NULL) {
            // Memory allocation failure. The API cannot report error,
            // so ignore the exception and return 0.
            res = 0;
            goto done;
        }
    }

    if (len != str_len) {
        res = 0;
        goto done;
    }
    res = (memcmp(utf8, str, (size_t)len) == 0);
#else
    PyObject *bytes = PyUnicode_AsUTF8String(unicode);
    if (bytes == NULL) {
        // Memory allocation failure. The API cannot report error,
        // so ignore the exception and return 0.
        res = 0;
        goto done;
    }

#if PY_VERSION_HEX >= 0x03000000
    len = PyBytes_GET_SIZE(bytes);
    utf8 = PyBytes_AS_STRING(bytes);
#else
    len = PyString_GET_SIZE(bytes);
    utf8 = PyString_AS_STRING(bytes);
#endif
    if (len != str_len) {
        Py_DECREF(bytes);
        res = 0;
        goto done;
    }

    res = (memcmp(utf8, str, (size_t)len) == 0);
    Py_DECREF(bytes);
#endif

done:
    // Restoring the saved state also discards any exception raised above.
    PyErr_Restore(exc_type, exc_value, exc_tb);
    return res;
}

// NUL-terminated variant: the length is taken with strlen().
static inline int
PyUnicode_EqualToUTF8(PyObject *unicode, const char *str)
{
    return PyUnicode_EqualToUTF8AndSize(unicode, str, (Py_ssize_t)strlen(str));
}
#endif


// gh-111138 added PyList_Extend() and PyList_Clear() to Python 3.13.0a2
#if PY_VERSION_HEX < 0x030D00A2
// Append the items of `iterable` to `list` via a slice assignment at
// index PY_SSIZE_T_MAX.
static inline int
PyList_Extend(PyObject *list, PyObject *iterable)
{
    return PyList_SetSlice(list, PY_SSIZE_T_MAX, PY_SSIZE_T_MAX, iterable);
}

// Remove every item of `list` by assigning NULL to the slice
// [0, PY_SSIZE_T_MAX).
static inline int
PyList_Clear(PyObject *list)
{
    return PyList_SetSlice(list, 0, PY_SSIZE_T_MAX, NULL);
}
#endif

// gh-111262 added PyDict_Pop() and PyDict_PopString() to Python 3.13.0a2
#if PY_VERSION_HEX < 0x030D00A2
// Remove `key` from `dict`. `result` may be NULL when the caller does not
// want the removed value.
// Returns 1 when the key was present (*result receives the value as a new
// reference), 0 when it was missing (KeyError is cleared), -1 on error;
// *result is set to NULL in the latter two cases.
static inline int
PyDict_Pop(PyObject *dict, PyObject *key, PyObject **result)
{
    PyObject *value;

    if (!PyDict_Check(dict)) {
        PyErr_BadInternalCall();
        if (result) {
            *result = NULL;
        }
        return -1;
    }

    // bpo-16991 added _PyDict_Pop() to Python 3.5.0b2.
    // Python 3.6.0b3 changed _PyDict_Pop() first argument type to PyObject*.
    // Python 3.13.0a1 removed _PyDict_Pop().
#if defined(PYPY_VERSION) || PY_VERSION_HEX < 0x030500b2 || PY_VERSION_HEX >= 0x030D0000
    value = PyObject_CallMethod(dict, "pop", "O", key);
#elif PY_VERSION_HEX < 0x030600b3
    value = _PyDict_Pop(_Py_CAST(PyDictObject*, dict), key, NULL);
#else
    value = _PyDict_Pop(dict, key, NULL);
#endif
    if (value == NULL) {
        if (result) {
            *result = NULL;
        }
        // Only KeyError means "missing key"; anything else is a real error.
        if (PyErr_Occurred() && !PyErr_ExceptionMatches(PyExc_KeyError)) {
            return -1;
        }
        PyErr_Clear();
        return 0;
    }
    if (result) {
        *result = value;
    }
    else {
        // Caller is not interested in the value: release our reference.
        Py_DECREF(value);
    }
    return 1;
}

// Same as PyDict_Pop(), with the key given as a C string.
static inline int
PyDict_PopString(PyObject *dict, const char *key, PyObject **result)
{
    PyObject *key_obj = PyUnicode_FromString(key);
    if (key_obj == NULL) {
        if (result != NULL) {
            *result = NULL;
        }
        return -1;
    }

    int res = PyDict_Pop(dict, key_obj, result);
    Py_DECREF(key_obj);
    return res;
}
#endif


#if PY_VERSION_HEX < 0x030200A4
// Python 3.2.0a4 added Py_hash_t type
typedef Py_ssize_t Py_hash_t;
#endif


// gh-111545 added Py_HashPointer() to Python 3.13.0a3
#if PY_VERSION_HEX < 0x030D00A3
// Public-name wrapper around _Py_HashPointer(); the const qualifier is cast
// away on versions that take the first branch of the #else below.
static inline Py_hash_t Py_HashPointer(const void *ptr)
{
#if PY_VERSION_HEX >= 0x030900A4 && !defined(PYPY_VERSION)
    return _Py_HashPointer(ptr);
#else
    return _Py_HashPointer(_Py_CAST(void*, ptr));
#endif
}
#endif


// Python 3.13a4 added a PyTime API.
// Use the private API added to Python 3.5.
#if PY_VERSION_HEX < 0x030D00A4 && PY_VERSION_HEX >= 0x03050000 typedef _PyTime_t PyTime_t; #define PyTime_MIN _PyTime_MIN #define PyTime_MAX _PyTime_MAX static inline double PyTime_AsSecondsDouble(PyTime_t t) { return _PyTime_AsSecondsDouble(t); } static inline int PyTime_Monotonic(PyTime_t *result) { return _PyTime_GetMonotonicClockWithInfo(result, NULL); } static inline int PyTime_Time(PyTime_t *result) { return _PyTime_GetSystemClockWithInfo(result, NULL); } static inline int PyTime_PerfCounter(PyTime_t *result) { #if PY_VERSION_HEX >= 0x03070000 && !defined(PYPY_VERSION) return _PyTime_GetPerfCounterWithInfo(result, NULL); #elif PY_VERSION_HEX >= 0x03070000 // Call time.perf_counter_ns() and convert Python int object to PyTime_t. // Cache time.perf_counter_ns() function for best performance. static PyObject *func = NULL; if (func == NULL) { PyObject *mod = PyImport_ImportModule("time"); if (mod == NULL) { return -1; } func = PyObject_GetAttrString(mod, "perf_counter_ns"); Py_DECREF(mod); if (func == NULL) { return -1; } } PyObject *res = PyObject_CallNoArgs(func); if (res == NULL) { return -1; } long long value = PyLong_AsLongLong(res); Py_DECREF(res); if (value == -1 && PyErr_Occurred()) { return -1; } Py_BUILD_ASSERT(sizeof(value) >= sizeof(PyTime_t)); *result = (PyTime_t)value; return 0; #else // Call time.perf_counter() and convert C double to PyTime_t. // Cache time.perf_counter() function for best performance. 
static PyObject *func = NULL; if (func == NULL) { PyObject *mod = PyImport_ImportModule("time"); if (mod == NULL) { return -1; } func = PyObject_GetAttrString(mod, "perf_counter"); Py_DECREF(mod); if (func == NULL) { return -1; } } PyObject *res = PyObject_CallNoArgs(func); if (res == NULL) { return -1; } double d = PyFloat_AsDouble(res); Py_DECREF(res); if (d == -1.0 && PyErr_Occurred()) { return -1; } // Avoid floor() to avoid having to link to libm *result = (PyTime_t)(d * 1e9); return 0; #endif } #endif // gh-111389 added hash constants to Python 3.13.0a5. These constants were // added first as private macros to Python 3.4.0b1 and PyPy 7.3.8. #if (!defined(PyHASH_BITS) \ && ((!defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x030400B1) \ || (defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x03070000 \ && PYPY_VERSION_NUM >= 0x07030800))) # define PyHASH_BITS _PyHASH_BITS # define PyHASH_MODULUS _PyHASH_MODULUS # define PyHASH_INF _PyHASH_INF # define PyHASH_IMAG _PyHASH_IMAG #endif // gh-111545 added Py_GetConstant() and Py_GetConstantBorrowed() // to Python 3.13.0a6 #if PY_VERSION_HEX < 0x030D00A6 && !defined(Py_CONSTANT_NONE) #define Py_CONSTANT_NONE 0 #define Py_CONSTANT_FALSE 1 #define Py_CONSTANT_TRUE 2 #define Py_CONSTANT_ELLIPSIS 3 #define Py_CONSTANT_NOT_IMPLEMENTED 4 #define Py_CONSTANT_ZERO 5 #define Py_CONSTANT_ONE 6 #define Py_CONSTANT_EMPTY_STR 7 #define Py_CONSTANT_EMPTY_BYTES 8 #define Py_CONSTANT_EMPTY_TUPLE 9 static inline PyObject* Py_GetConstant(unsigned int constant_id) { static PyObject* constants[Py_CONSTANT_EMPTY_TUPLE + 1] = {NULL}; if (constants[Py_CONSTANT_NONE] == NULL) { constants[Py_CONSTANT_NONE] = Py_None; constants[Py_CONSTANT_FALSE] = Py_False; constants[Py_CONSTANT_TRUE] = Py_True; constants[Py_CONSTANT_ELLIPSIS] = Py_Ellipsis; constants[Py_CONSTANT_NOT_IMPLEMENTED] = Py_NotImplemented; constants[Py_CONSTANT_ZERO] = PyLong_FromLong(0); if (constants[Py_CONSTANT_ZERO] == NULL) { goto fatal_error; } constants[Py_CONSTANT_ONE] = 
PyLong_FromLong(1); if (constants[Py_CONSTANT_ONE] == NULL) { goto fatal_error; } constants[Py_CONSTANT_EMPTY_STR] = PyUnicode_FromStringAndSize("", 0); if (constants[Py_CONSTANT_EMPTY_STR] == NULL) { goto fatal_error; } constants[Py_CONSTANT_EMPTY_BYTES] = PyBytes_FromStringAndSize("", 0); if (constants[Py_CONSTANT_EMPTY_BYTES] == NULL) { goto fatal_error; } constants[Py_CONSTANT_EMPTY_TUPLE] = PyTuple_New(0); if (constants[Py_CONSTANT_EMPTY_TUPLE] == NULL) { goto fatal_error; } // goto dance to avoid compiler warnings about Py_FatalError() goto init_done; fatal_error: // This case should never happen Py_FatalError("Py_GetConstant() failed to get constants"); } init_done: if (constant_id <= Py_CONSTANT_EMPTY_TUPLE) { return Py_NewRef(constants[constant_id]); } else { PyErr_BadInternalCall(); return NULL; } } static inline PyObject* Py_GetConstantBorrowed(unsigned int constant_id) { PyObject *obj = Py_GetConstant(constant_id); Py_XDECREF(obj); return obj; } #endif // gh-114329 added PyList_GetItemRef() to Python 3.13.0a4 #if PY_VERSION_HEX < 0x030D00A4 static inline PyObject * PyList_GetItemRef(PyObject *op, Py_ssize_t index) { PyObject *item = PyList_GetItem(op, index); Py_XINCREF(item); return item; } #endif // gh-114329 added PyList_GetItemRef() to Python 3.13.0a4 #if PY_VERSION_HEX < 0x030D00A4 static inline int PyDict_SetDefaultRef(PyObject *d, PyObject *key, PyObject *default_value, PyObject **result) { PyObject *value; if (PyDict_GetItemRef(d, key, &value) < 0) { // get error if (result) { *result = NULL; } return -1; } if (value != NULL) { // present if (result) { *result = value; } else { Py_DECREF(value); } return 1; } // missing: set the item if (PyDict_SetItem(d, key, default_value) < 0) { // set error if (result) { *result = NULL; } return -1; } if (result) { *result = Py_NewRef(default_value); } return 0; } #endif #if PY_VERSION_HEX < 0x030D00B3 # define Py_BEGIN_CRITICAL_SECTION(op) { # define Py_END_CRITICAL_SECTION() } # define 
Py_BEGIN_CRITICAL_SECTION2(a, b) { # define Py_END_CRITICAL_SECTION2() } #endif #if PY_VERSION_HEX < 0x030E0000 && PY_VERSION_HEX >= 0x03060000 && !defined(PYPY_VERSION) typedef struct PyUnicodeWriter PyUnicodeWriter; static inline void PyUnicodeWriter_Discard(PyUnicodeWriter *writer) { _PyUnicodeWriter_Dealloc((_PyUnicodeWriter*)writer); PyMem_Free(writer); } static inline PyUnicodeWriter* PyUnicodeWriter_Create(Py_ssize_t length) { if (length < 0) { PyErr_SetString(PyExc_ValueError, "length must be positive"); return NULL; } const size_t size = sizeof(_PyUnicodeWriter); PyUnicodeWriter *pub_writer = (PyUnicodeWriter *)PyMem_Malloc(size); if (pub_writer == _Py_NULL) { PyErr_NoMemory(); return _Py_NULL; } _PyUnicodeWriter *writer = (_PyUnicodeWriter *)pub_writer; _PyUnicodeWriter_Init(writer); if (_PyUnicodeWriter_Prepare(writer, length, 127) < 0) { PyUnicodeWriter_Discard(pub_writer); return NULL; } writer->overallocate = 1; return pub_writer; } static inline PyObject* PyUnicodeWriter_Finish(PyUnicodeWriter *writer) { PyObject *str = _PyUnicodeWriter_Finish((_PyUnicodeWriter*)writer); assert(((_PyUnicodeWriter*)writer)->buffer == NULL); PyMem_Free(writer); return str; } static inline int PyUnicodeWriter_WriteChar(PyUnicodeWriter *writer, Py_UCS4 ch) { if (ch > 0x10ffff) { PyErr_SetString(PyExc_ValueError, "character must be in range(0x110000)"); return -1; } return _PyUnicodeWriter_WriteChar((_PyUnicodeWriter*)writer, ch); } static inline int PyUnicodeWriter_WriteStr(PyUnicodeWriter *writer, PyObject *obj) { PyObject *str = PyObject_Str(obj); if (str == NULL) { return -1; } int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str); Py_DECREF(str); return res; } static inline int PyUnicodeWriter_WriteRepr(PyUnicodeWriter *writer, PyObject *obj) { PyObject *str = PyObject_Repr(obj); if (str == NULL) { return -1; } int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str); Py_DECREF(str); return res; } static inline int 
PyUnicodeWriter_WriteUTF8(PyUnicodeWriter *writer, const char *str, Py_ssize_t size) { if (size < 0) { size = (Py_ssize_t)strlen(str); } PyObject *str_obj = PyUnicode_FromStringAndSize(str, size); if (str_obj == _Py_NULL) { return -1; } int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str_obj); Py_DECREF(str_obj); return res; } static inline int PyUnicodeWriter_WriteASCII(PyUnicodeWriter *writer, const char *str, Py_ssize_t size) { if (size < 0) { size = (Py_ssize_t)strlen(str); } return _PyUnicodeWriter_WriteASCIIString((_PyUnicodeWriter*)writer, str, size); } static inline int PyUnicodeWriter_WriteWideChar(PyUnicodeWriter *writer, const wchar_t *str, Py_ssize_t size) { if (size < 0) { size = (Py_ssize_t)wcslen(str); } PyObject *str_obj = PyUnicode_FromWideChar(str, size); if (str_obj == _Py_NULL) { return -1; } int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str_obj); Py_DECREF(str_obj); return res; } static inline int PyUnicodeWriter_WriteSubstring(PyUnicodeWriter *writer, PyObject *str, Py_ssize_t start, Py_ssize_t end) { if (!PyUnicode_Check(str)) { PyErr_Format(PyExc_TypeError, "expect str, not %s", Py_TYPE(str)->tp_name); return -1; } if (start < 0 || start > end) { PyErr_Format(PyExc_ValueError, "invalid start argument"); return -1; } if (end > PyUnicode_GET_LENGTH(str)) { PyErr_Format(PyExc_ValueError, "invalid end argument"); return -1; } return _PyUnicodeWriter_WriteSubstring((_PyUnicodeWriter*)writer, str, start, end); } static inline int PyUnicodeWriter_Format(PyUnicodeWriter *writer, const char *format, ...) 
{ va_list vargs; va_start(vargs, format); PyObject *str = PyUnicode_FromFormatV(format, vargs); va_end(vargs); if (str == _Py_NULL) { return -1; } int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str); Py_DECREF(str); return res; } #endif // PY_VERSION_HEX < 0x030E0000 // gh-116560 added PyLong_GetSign() to Python 3.14.0a0 #if PY_VERSION_HEX < 0x030E00A0 static inline int PyLong_GetSign(PyObject *obj, int *sign) { if (!PyLong_Check(obj)) { PyErr_Format(PyExc_TypeError, "expect int, got %s", Py_TYPE(obj)->tp_name); return -1; } *sign = _PyLong_Sign(obj); return 0; } #endif // gh-126061 added PyLong_IsPositive/Negative/Zero() to Python in 3.14.0a2 #if PY_VERSION_HEX < 0x030E00A2 static inline int PyLong_IsPositive(PyObject *obj) { if (!PyLong_Check(obj)) { PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); return -1; } return _PyLong_Sign(obj) == 1; } static inline int PyLong_IsNegative(PyObject *obj) { if (!PyLong_Check(obj)) { PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); return -1; } return _PyLong_Sign(obj) == -1; } static inline int PyLong_IsZero(PyObject *obj) { if (!PyLong_Check(obj)) { PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); return -1; } return _PyLong_Sign(obj) == 0; } #endif // gh-124502 added PyUnicode_Equal() to Python 3.14.0a0 #if PY_VERSION_HEX < 0x030E00A0 static inline int PyUnicode_Equal(PyObject *str1, PyObject *str2) { if (!PyUnicode_Check(str1)) { PyErr_Format(PyExc_TypeError, "first argument must be str, not %s", Py_TYPE(str1)->tp_name); return -1; } if (!PyUnicode_Check(str2)) { PyErr_Format(PyExc_TypeError, "second argument must be str, not %s", Py_TYPE(str2)->tp_name); return -1; } #if PY_VERSION_HEX >= 0x030d0000 && !defined(PYPY_VERSION) PyAPI_FUNC(int) _PyUnicode_Equal(PyObject *str1, PyObject *str2); return _PyUnicode_Equal(str1, str2); #elif PY_VERSION_HEX >= 0x03060000 && !defined(PYPY_VERSION) return _PyUnicode_EQ(str1, str2); 
#elif PY_VERSION_HEX >= 0x03090000 && defined(PYPY_VERSION) return _PyUnicode_EQ(str1, str2); #else return (PyUnicode_Compare(str1, str2) == 0); #endif } #endif // gh-121645 added PyBytes_Join() to Python 3.14.0a0 #if PY_VERSION_HEX < 0x030E00A0 static inline PyObject* PyBytes_Join(PyObject *sep, PyObject *iterable) { return _PyBytes_Join(sep, iterable); } #endif #if PY_VERSION_HEX < 0x030E00A0 static inline Py_hash_t Py_HashBuffer(const void *ptr, Py_ssize_t len) { #if PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) PyAPI_FUNC(Py_hash_t) _Py_HashBytes(const void *src, Py_ssize_t len); return _Py_HashBytes(ptr, len); #else Py_hash_t hash; PyObject *bytes = PyBytes_FromStringAndSize((const char*)ptr, len); if (bytes == NULL) { return -1; } hash = PyObject_Hash(bytes); Py_DECREF(bytes); return hash; #endif } #endif #if PY_VERSION_HEX < 0x030E00A0 static inline int PyIter_NextItem(PyObject *iter, PyObject **item) { iternextfunc tp_iternext; assert(iter != NULL); assert(item != NULL); tp_iternext = Py_TYPE(iter)->tp_iternext; if (tp_iternext == NULL) { *item = NULL; PyErr_Format(PyExc_TypeError, "expected an iterator, got '%s'", Py_TYPE(iter)->tp_name); return -1; } if ((*item = tp_iternext(iter))) { return 1; } if (!PyErr_Occurred()) { return 0; } if (PyErr_ExceptionMatches(PyExc_StopIteration)) { PyErr_Clear(); return 0; } return -1; } #endif #if PY_VERSION_HEX < 0x030E00A0 static inline PyObject* PyLong_FromInt32(int32_t value) { Py_BUILD_ASSERT(sizeof(long) >= 4); return PyLong_FromLong(value); } static inline PyObject* PyLong_FromInt64(int64_t value) { Py_BUILD_ASSERT(sizeof(long long) >= 8); return PyLong_FromLongLong(value); } static inline PyObject* PyLong_FromUInt32(uint32_t value) { Py_BUILD_ASSERT(sizeof(unsigned long) >= 4); return PyLong_FromUnsignedLong(value); } static inline PyObject* PyLong_FromUInt64(uint64_t value) { Py_BUILD_ASSERT(sizeof(unsigned long long) >= 8); return PyLong_FromUnsignedLongLong(value); } static inline int 
PyLong_AsInt32(PyObject *obj, int32_t *pvalue) { Py_BUILD_ASSERT(sizeof(int) == 4); int value = PyLong_AsInt(obj); if (value == -1 && PyErr_Occurred()) { return -1; } *pvalue = (int32_t)value; return 0; } static inline int PyLong_AsInt64(PyObject *obj, int64_t *pvalue) { Py_BUILD_ASSERT(sizeof(long long) == 8); long long value = PyLong_AsLongLong(obj); if (value == -1 && PyErr_Occurred()) { return -1; } *pvalue = (int64_t)value; return 0; } static inline int PyLong_AsUInt32(PyObject *obj, uint32_t *pvalue) { Py_BUILD_ASSERT(sizeof(long) >= 4); unsigned long value = PyLong_AsUnsignedLong(obj); if (value == (unsigned long)-1 && PyErr_Occurred()) { return -1; } #if SIZEOF_LONG > 4 if ((unsigned long)UINT32_MAX < value) { PyErr_SetString(PyExc_OverflowError, "Python int too large to convert to C uint32_t"); return -1; } #endif *pvalue = (uint32_t)value; return 0; } static inline int PyLong_AsUInt64(PyObject *obj, uint64_t *pvalue) { Py_BUILD_ASSERT(sizeof(long long) == 8); unsigned long long value = PyLong_AsUnsignedLongLong(obj); if (value == (unsigned long long)-1 && PyErr_Occurred()) { return -1; } *pvalue = (uint64_t)value; return 0; } #endif // gh-102471 added import and export API for integers to 3.14.0a2. #if PY_VERSION_HEX < 0x030E00A2 && PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) // Helpers to access PyLongObject internals. static inline void _PyLong_SetSignAndDigitCount(PyLongObject *op, int sign, Py_ssize_t size) { #if PY_VERSION_HEX >= 0x030C0000 op->long_value.lv_tag = (uintptr_t)(1 - sign) | ((uintptr_t)(size) << 3); #elif PY_VERSION_HEX >= 0x030900A4 Py_SET_SIZE(op, sign * size); #else Py_SIZE(op) = sign * size; #endif } static inline Py_ssize_t _PyLong_DigitCount(const PyLongObject *op) { #if PY_VERSION_HEX >= 0x030C0000 return (Py_ssize_t)(op->long_value.lv_tag >> 3); #else return _PyLong_Sign((PyObject*)op) < 0 ? 
-Py_SIZE(op) : Py_SIZE(op); #endif } static inline digit* _PyLong_GetDigits(const PyLongObject *op) { #if PY_VERSION_HEX >= 0x030C0000 return (digit*)(op->long_value.ob_digit); #else return (digit*)(op->ob_digit); #endif } typedef struct PyLongLayout { uint8_t bits_per_digit; uint8_t digit_size; int8_t digits_order; int8_t digit_endianness; } PyLongLayout; typedef struct PyLongExport { int64_t value; uint8_t negative; Py_ssize_t ndigits; const void *digits; Py_uintptr_t _reserved; } PyLongExport; typedef struct PyLongWriter PyLongWriter; static inline const PyLongLayout* PyLong_GetNativeLayout(void) { static const PyLongLayout PyLong_LAYOUT = { PyLong_SHIFT, sizeof(digit), -1, // least significant first PY_LITTLE_ENDIAN ? -1 : 1, }; return &PyLong_LAYOUT; } static inline int PyLong_Export(PyObject *obj, PyLongExport *export_long) { if (!PyLong_Check(obj)) { memset(export_long, 0, sizeof(*export_long)); PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); return -1; } // Fast-path: try to convert to a int64_t PyLongObject *self = (PyLongObject*)obj; int overflow; #if SIZEOF_LONG == 8 long value = PyLong_AsLongAndOverflow(obj, &overflow); #else // Windows has 32-bit long, so use 64-bit long long instead long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); #endif Py_BUILD_ASSERT(sizeof(value) == sizeof(int64_t)); // the function cannot fail since obj is a PyLongObject assert(!(value == -1 && PyErr_Occurred())); if (!overflow) { export_long->value = value; export_long->negative = 0; export_long->ndigits = 0; export_long->digits = 0; export_long->_reserved = 0; } else { export_long->value = 0; export_long->negative = _PyLong_Sign(obj) < 0; export_long->ndigits = _PyLong_DigitCount(self); if (export_long->ndigits == 0) { export_long->ndigits = 1; } export_long->digits = _PyLong_GetDigits(self); export_long->_reserved = (Py_uintptr_t)Py_NewRef(obj); } return 0; } static inline void PyLong_FreeExport(PyLongExport *export_long) { PyObject 
*obj = (PyObject*)export_long->_reserved; if (obj) { export_long->_reserved = 0; Py_DECREF(obj); } } static inline PyLongWriter* PyLongWriter_Create(int negative, Py_ssize_t ndigits, void **digits) { if (ndigits <= 0) { PyErr_SetString(PyExc_ValueError, "ndigits must be positive"); return NULL; } assert(digits != NULL); PyLongObject *obj = _PyLong_New(ndigits); if (obj == NULL) { return NULL; } _PyLong_SetSignAndDigitCount(obj, negative?-1:1, ndigits); *digits = _PyLong_GetDigits(obj); return (PyLongWriter*)obj; } static inline void PyLongWriter_Discard(PyLongWriter *writer) { PyLongObject *obj = (PyLongObject *)writer; assert(Py_REFCNT(obj) == 1); Py_DECREF(obj); } static inline PyObject* PyLongWriter_Finish(PyLongWriter *writer) { PyObject *obj = (PyObject *)writer; PyLongObject *self = (PyLongObject*)obj; Py_ssize_t j = _PyLong_DigitCount(self); Py_ssize_t i = j; int sign = _PyLong_Sign(obj); assert(Py_REFCNT(obj) == 1); // Normalize and get singleton if possible while (i > 0 && _PyLong_GetDigits(self)[i-1] == 0) { --i; } if (i != j) { if (i == 0) { sign = 0; } _PyLong_SetSignAndDigitCount(self, sign, i); } if (i <= 1) { long val = sign * (long)(_PyLong_GetDigits(self)[0]); Py_DECREF(obj); return PyLong_FromLong(val); } return obj; } #endif #if PY_VERSION_HEX < 0x030C00A3 # define Py_T_SHORT T_SHORT # define Py_T_INT T_INT # define Py_T_LONG T_LONG # define Py_T_FLOAT T_FLOAT # define Py_T_DOUBLE T_DOUBLE # define Py_T_STRING T_STRING # define _Py_T_OBJECT T_OBJECT # define Py_T_CHAR T_CHAR # define Py_T_BYTE T_BYTE # define Py_T_UBYTE T_UBYTE # define Py_T_USHORT T_USHORT # define Py_T_UINT T_UINT # define Py_T_ULONG T_ULONG # define Py_T_STRING_INPLACE T_STRING_INPLACE # define Py_T_BOOL T_BOOL # define Py_T_OBJECT_EX T_OBJECT_EX # define Py_T_LONGLONG T_LONGLONG # define Py_T_ULONGLONG T_ULONGLONG # define Py_T_PYSSIZET T_PYSSIZET # if PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) # define _Py_T_NONE T_NONE # endif # define Py_READONLY READONLY # 
define Py_AUDIT_READ READ_RESTRICTED # define _Py_WRITE_RESTRICTED PY_WRITE_RESTRICTED #endif // gh-127350 added Py_fopen() and Py_fclose() to Python 3.14a4 #if PY_VERSION_HEX < 0x030E00A4 static inline FILE* Py_fopen(PyObject *path, const char *mode) { #if 0x030400A2 <= PY_VERSION_HEX && !defined(PYPY_VERSION) PyAPI_FUNC(FILE*) _Py_fopen_obj(PyObject *path, const char *mode); return _Py_fopen_obj(path, mode); #else FILE *f; PyObject *bytes; #if PY_VERSION_HEX >= 0x03000000 if (!PyUnicode_FSConverter(path, &bytes)) { return NULL; } #else if (!PyString_Check(path)) { PyErr_SetString(PyExc_TypeError, "except str"); return NULL; } bytes = Py_NewRef(path); #endif const char *path_bytes = PyBytes_AS_STRING(bytes); f = fopen(path_bytes, mode); Py_DECREF(bytes); if (f == NULL) { PyErr_SetFromErrnoWithFilenameObject(PyExc_OSError, path); return NULL; } return f; #endif } static inline int Py_fclose(FILE *file) { return fclose(file); } #endif #if 0x03090000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030E0000 && !defined(PYPY_VERSION) static inline PyObject* PyConfig_Get(const char *name) { typedef enum { _PyConfig_MEMBER_INT, _PyConfig_MEMBER_UINT, _PyConfig_MEMBER_ULONG, _PyConfig_MEMBER_BOOL, _PyConfig_MEMBER_WSTR, _PyConfig_MEMBER_WSTR_OPT, _PyConfig_MEMBER_WSTR_LIST, } PyConfigMemberType; typedef struct { const char *name; size_t offset; PyConfigMemberType type; const char *sys_attr; } PyConfigSpec; #define PYTHONCAPI_COMPAT_SPEC(MEMBER, TYPE, sys_attr) \ {#MEMBER, offsetof(PyConfig, MEMBER), \ _PyConfig_MEMBER_##TYPE, sys_attr} static const PyConfigSpec config_spec[] = { PYTHONCAPI_COMPAT_SPEC(argv, WSTR_LIST, "argv"), PYTHONCAPI_COMPAT_SPEC(base_exec_prefix, WSTR_OPT, "base_exec_prefix"), PYTHONCAPI_COMPAT_SPEC(base_executable, WSTR_OPT, "_base_executable"), PYTHONCAPI_COMPAT_SPEC(base_prefix, WSTR_OPT, "base_prefix"), PYTHONCAPI_COMPAT_SPEC(bytes_warning, UINT, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(exec_prefix, WSTR_OPT, "exec_prefix"), PYTHONCAPI_COMPAT_SPEC(executable, 
WSTR_OPT, "executable"), PYTHONCAPI_COMPAT_SPEC(inspect, BOOL, _Py_NULL), #if 0x030C0000 <= PY_VERSION_HEX PYTHONCAPI_COMPAT_SPEC(int_max_str_digits, UINT, _Py_NULL), #endif PYTHONCAPI_COMPAT_SPEC(interactive, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(module_search_paths, WSTR_LIST, "path"), PYTHONCAPI_COMPAT_SPEC(optimization_level, UINT, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(parser_debug, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(platlibdir, WSTR, "platlibdir"), PYTHONCAPI_COMPAT_SPEC(prefix, WSTR_OPT, "prefix"), PYTHONCAPI_COMPAT_SPEC(pycache_prefix, WSTR_OPT, "pycache_prefix"), PYTHONCAPI_COMPAT_SPEC(quiet, BOOL, _Py_NULL), #if 0x030B0000 <= PY_VERSION_HEX PYTHONCAPI_COMPAT_SPEC(stdlib_dir, WSTR_OPT, "_stdlib_dir"), #endif PYTHONCAPI_COMPAT_SPEC(use_environment, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(verbose, UINT, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(warnoptions, WSTR_LIST, "warnoptions"), PYTHONCAPI_COMPAT_SPEC(write_bytecode, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(xoptions, WSTR_LIST, "_xoptions"), PYTHONCAPI_COMPAT_SPEC(buffered_stdio, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(check_hash_pycs_mode, WSTR, _Py_NULL), #if 0x030B0000 <= PY_VERSION_HEX PYTHONCAPI_COMPAT_SPEC(code_debug_ranges, BOOL, _Py_NULL), #endif PYTHONCAPI_COMPAT_SPEC(configure_c_stdio, BOOL, _Py_NULL), #if 0x030D0000 <= PY_VERSION_HEX PYTHONCAPI_COMPAT_SPEC(cpu_count, INT, _Py_NULL), #endif PYTHONCAPI_COMPAT_SPEC(dev_mode, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(dump_refs, BOOL, _Py_NULL), #if 0x030B0000 <= PY_VERSION_HEX PYTHONCAPI_COMPAT_SPEC(dump_refs_file, WSTR_OPT, _Py_NULL), #endif #ifdef Py_GIL_DISABLED PYTHONCAPI_COMPAT_SPEC(enable_gil, INT, _Py_NULL), #endif PYTHONCAPI_COMPAT_SPEC(faulthandler, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(filesystem_encoding, WSTR, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(filesystem_errors, WSTR, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(hash_seed, ULONG, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(home, WSTR_OPT, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(import_time, BOOL, _Py_NULL), 
PYTHONCAPI_COMPAT_SPEC(install_signal_handlers, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(isolated, BOOL, _Py_NULL), #ifdef MS_WINDOWS PYTHONCAPI_COMPAT_SPEC(legacy_windows_stdio, BOOL, _Py_NULL), #endif PYTHONCAPI_COMPAT_SPEC(malloc_stats, BOOL, _Py_NULL), #if 0x030A0000 <= PY_VERSION_HEX PYTHONCAPI_COMPAT_SPEC(orig_argv, WSTR_LIST, "orig_argv"), #endif PYTHONCAPI_COMPAT_SPEC(parse_argv, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(pathconfig_warnings, BOOL, _Py_NULL), #if 0x030C0000 <= PY_VERSION_HEX PYTHONCAPI_COMPAT_SPEC(perf_profiling, UINT, _Py_NULL), #endif PYTHONCAPI_COMPAT_SPEC(program_name, WSTR, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(run_command, WSTR_OPT, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(run_filename, WSTR_OPT, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(run_module, WSTR_OPT, _Py_NULL), #if 0x030B0000 <= PY_VERSION_HEX PYTHONCAPI_COMPAT_SPEC(safe_path, BOOL, _Py_NULL), #endif PYTHONCAPI_COMPAT_SPEC(show_ref_count, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(site_import, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(skip_source_first_line, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(stdio_encoding, WSTR, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(stdio_errors, WSTR, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(tracemalloc, UINT, _Py_NULL), #if 0x030B0000 <= PY_VERSION_HEX PYTHONCAPI_COMPAT_SPEC(use_frozen_modules, BOOL, _Py_NULL), #endif PYTHONCAPI_COMPAT_SPEC(use_hash_seed, BOOL, _Py_NULL), PYTHONCAPI_COMPAT_SPEC(user_site_directory, BOOL, _Py_NULL), #if 0x030A0000 <= PY_VERSION_HEX PYTHONCAPI_COMPAT_SPEC(warn_default_encoding, BOOL, _Py_NULL), #endif }; #undef PYTHONCAPI_COMPAT_SPEC const PyConfigSpec *spec; int found = 0; for (size_t i=0; i < sizeof(config_spec) / sizeof(config_spec[0]); i++) { spec = &config_spec[i]; if (strcmp(spec->name, name) == 0) { found = 1; break; } } if (found) { if (spec->sys_attr != NULL) { PyObject *value = PySys_GetObject(spec->sys_attr); if (value == NULL) { PyErr_Format(PyExc_RuntimeError, "lost sys.%s", spec->sys_attr); return NULL; } return Py_NewRef(value); } 
PyAPI_FUNC(const PyConfig*) _Py_GetConfig(void); const PyConfig *config = _Py_GetConfig(); void *member = (char *)config + spec->offset; switch (spec->type) { case _PyConfig_MEMBER_INT: case _PyConfig_MEMBER_UINT: { int value = *(int *)member; return PyLong_FromLong(value); } case _PyConfig_MEMBER_BOOL: { int value = *(int *)member; return PyBool_FromLong(value != 0); } case _PyConfig_MEMBER_ULONG: { unsigned long value = *(unsigned long *)member; return PyLong_FromUnsignedLong(value); } case _PyConfig_MEMBER_WSTR: case _PyConfig_MEMBER_WSTR_OPT: { wchar_t *wstr = *(wchar_t **)member; if (wstr != NULL) { return PyUnicode_FromWideChar(wstr, -1); } else { return Py_NewRef(Py_None); } } case _PyConfig_MEMBER_WSTR_LIST: { const PyWideStringList *list = (const PyWideStringList *)member; PyObject *tuple = PyTuple_New(list->length); if (tuple == NULL) { return NULL; } for (Py_ssize_t i = 0; i < list->length; i++) { PyObject *item = PyUnicode_FromWideChar(list->items[i], -1); if (item == NULL) { Py_DECREF(tuple); return NULL; } PyTuple_SET_ITEM(tuple, i, item); } return tuple; } default: Py_UNREACHABLE(); } } PyErr_Format(PyExc_ValueError, "unknown config option name: %s", name); return NULL; } static inline int PyConfig_GetInt(const char *name, int *value) { PyObject *obj = PyConfig_Get(name); if (obj == NULL) { return -1; } if (!PyLong_Check(obj)) { Py_DECREF(obj); PyErr_Format(PyExc_TypeError, "config option %s is not an int", name); return -1; } int as_int = PyLong_AsInt(obj); Py_DECREF(obj); if (as_int == -1 && PyErr_Occurred()) { PyErr_Format(PyExc_OverflowError, "config option %s value does not fit into a C int", name); return -1; } *value = as_int; return 0; } #endif // PY_VERSION_HEX > 0x03090000 && !defined(PYPY_VERSION) // gh-133144 added PyUnstable_Object_IsUniquelyReferenced() to Python 3.14.0b1. // Adapted from _PyObject_IsUniquelyReferenced() implementation. 
#if PY_VERSION_HEX < 0x030E00B0 static inline int PyUnstable_Object_IsUniquelyReferenced(PyObject *obj) { #if !defined(Py_GIL_DISABLED) return Py_REFCNT(obj) == 1; #else // NOTE: the entire ob_ref_shared field must be zero, including flags, to // ensure that other threads cannot concurrently create new references to // this object. return (_Py_IsOwnedByCurrentThread(obj) && _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local) == 1 && _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared) == 0); #endif } #endif #if PY_VERSION_HEX < 0x030F0000 static inline PyObject* PySys_GetAttrString(const char *name) { #if PY_VERSION_HEX >= 0x03000000 PyObject *value = Py_XNewRef(PySys_GetObject(name)); #else PyObject *value = Py_XNewRef(PySys_GetObject((char*)name)); #endif if (value != NULL) { return value; } if (!PyErr_Occurred()) { PyErr_Format(PyExc_RuntimeError, "lost sys.%s", name); } return NULL; } static inline PyObject* PySys_GetAttr(PyObject *name) { #if PY_VERSION_HEX >= 0x03000000 const char *name_str = PyUnicode_AsUTF8(name); #else const char *name_str = PyString_AsString(name); #endif if (name_str == NULL) { return NULL; } return PySys_GetAttrString(name_str); } static inline int PySys_GetOptionalAttrString(const char *name, PyObject **value) { #if PY_VERSION_HEX >= 0x03000000 *value = Py_XNewRef(PySys_GetObject(name)); #else *value = Py_XNewRef(PySys_GetObject((char*)name)); #endif if (*value != NULL) { return 1; } return 0; } static inline int PySys_GetOptionalAttr(PyObject *name, PyObject **value) { #if PY_VERSION_HEX >= 0x03000000 const char *name_str = PyUnicode_AsUTF8(name); #else const char *name_str = PyString_AsString(name); #endif if (name_str == NULL) { *value = NULL; return -1; } return PySys_GetOptionalAttrString(name_str, value); } #endif // PY_VERSION_HEX < 0x030F00A1 #if PY_VERSION_HEX < 0x030F00A1 typedef struct PyBytesWriter { char small_buffer[256]; PyObject *obj; Py_ssize_t size; } PyBytesWriter; static inline Py_ssize_t 
_PyBytesWriter_GetAllocated(PyBytesWriter *writer) { if (writer->obj == NULL) { return sizeof(writer->small_buffer); } else { return PyBytes_GET_SIZE(writer->obj); } } static inline int _PyBytesWriter_Resize_impl(PyBytesWriter *writer, Py_ssize_t size, int resize) { int overallocate = resize; assert(size >= 0); if (size <= _PyBytesWriter_GetAllocated(writer)) { return 0; } if (overallocate) { #ifdef MS_WINDOWS /* On Windows, overallocate by 50% is the best factor */ if (size <= (PY_SSIZE_T_MAX - size / 2)) { size += size / 2; } #else /* On Linux, overallocate by 25% is the best factor */ if (size <= (PY_SSIZE_T_MAX - size / 4)) { size += size / 4; } #endif } if (writer->obj != NULL) { if (_PyBytes_Resize(&writer->obj, size)) { return -1; } assert(writer->obj != NULL); } else { writer->obj = PyBytes_FromStringAndSize(NULL, size); if (writer->obj == NULL) { return -1; } if (resize) { assert((size_t)size > sizeof(writer->small_buffer)); memcpy(PyBytes_AS_STRING(writer->obj), writer->small_buffer, sizeof(writer->small_buffer)); } } return 0; } static inline void* PyBytesWriter_GetData(PyBytesWriter *writer) { if (writer->obj == NULL) { return writer->small_buffer; } else { return PyBytes_AS_STRING(writer->obj); } } static inline Py_ssize_t PyBytesWriter_GetSize(PyBytesWriter *writer) { return writer->size; } static inline void PyBytesWriter_Discard(PyBytesWriter *writer) { if (writer == NULL) { return; } Py_XDECREF(writer->obj); PyMem_Free(writer); } static inline PyBytesWriter* PyBytesWriter_Create(Py_ssize_t size) { if (size < 0) { PyErr_SetString(PyExc_ValueError, "size must be >= 0"); return NULL; } PyBytesWriter *writer = (PyBytesWriter*)PyMem_Malloc(sizeof(PyBytesWriter)); if (writer == NULL) { PyErr_NoMemory(); return NULL; } writer->obj = NULL; writer->size = 0; if (size >= 1) { if (_PyBytesWriter_Resize_impl(writer, size, 0) < 0) { PyBytesWriter_Discard(writer); return NULL; } writer->size = size; } return writer; } static inline PyObject* 
PyBytesWriter_FinishWithSize(PyBytesWriter *writer, Py_ssize_t size) { PyObject *result; if (size == 0) { result = PyBytes_FromStringAndSize("", 0); } else if (writer->obj != NULL) { if (size != PyBytes_GET_SIZE(writer->obj)) { if (_PyBytes_Resize(&writer->obj, size)) { goto error; } } result = writer->obj; writer->obj = NULL; } else { result = PyBytes_FromStringAndSize(writer->small_buffer, size); } PyBytesWriter_Discard(writer); return result; error: PyBytesWriter_Discard(writer); return NULL; } static inline PyObject* PyBytesWriter_Finish(PyBytesWriter *writer) { return PyBytesWriter_FinishWithSize(writer, writer->size); } static inline PyObject* PyBytesWriter_FinishWithPointer(PyBytesWriter *writer, void *buf) { Py_ssize_t size = (char*)buf - (char*)PyBytesWriter_GetData(writer); if (size < 0 || size > _PyBytesWriter_GetAllocated(writer)) { PyBytesWriter_Discard(writer); PyErr_SetString(PyExc_ValueError, "invalid end pointer"); return NULL; } return PyBytesWriter_FinishWithSize(writer, size); } static inline int PyBytesWriter_Resize(PyBytesWriter *writer, Py_ssize_t size) { if (size < 0) { PyErr_SetString(PyExc_ValueError, "size must be >= 0"); return -1; } if (_PyBytesWriter_Resize_impl(writer, size, 1) < 0) { return -1; } writer->size = size; return 0; } static inline int PyBytesWriter_Grow(PyBytesWriter *writer, Py_ssize_t size) { if (size < 0 && writer->size + size < 0) { PyErr_SetString(PyExc_ValueError, "invalid size"); return -1; } if (size > PY_SSIZE_T_MAX - writer->size) { PyErr_NoMemory(); return -1; } size = writer->size + size; if (_PyBytesWriter_Resize_impl(writer, size, 1) < 0) { return -1; } writer->size = size; return 0; } static inline void* PyBytesWriter_GrowAndUpdatePointer(PyBytesWriter *writer, Py_ssize_t size, void *buf) { Py_ssize_t pos = (char*)buf - (char*)PyBytesWriter_GetData(writer); if (PyBytesWriter_Grow(writer, size) < 0) { return NULL; } return (char*)PyBytesWriter_GetData(writer) + pos; } static inline int 
PyBytesWriter_WriteBytes(PyBytesWriter *writer, const void *bytes, Py_ssize_t size) { if (size < 0) { size_t len = strlen((const char*)bytes); if (len > (size_t)PY_SSIZE_T_MAX) { PyErr_NoMemory(); return -1; } size = (Py_ssize_t)len; } Py_ssize_t pos = writer->size; if (PyBytesWriter_Grow(writer, size) < 0) { return -1; } char *buf = (char*)PyBytesWriter_GetData(writer); memcpy(buf + pos, bytes, (size_t)size); return 0; } static inline int PyBytesWriter_Format(PyBytesWriter *writer, const char *format, ...) Py_GCC_ATTRIBUTE((format(printf, 2, 3))); static inline int PyBytesWriter_Format(PyBytesWriter *writer, const char *format, ...) { va_list vargs; va_start(vargs, format); PyObject *str = PyBytes_FromFormatV(format, vargs); va_end(vargs); if (str == NULL) { return -1; } int res = PyBytesWriter_WriteBytes(writer, PyBytes_AS_STRING(str), PyBytes_GET_SIZE(str)); Py_DECREF(str); return res; } #endif // PY_VERSION_HEX < 0x030F00A1 #ifdef __cplusplus } #endif #endif // PYTHONCAPI_COMPAT ================================================ FILE: asyncpg/protocol/record/pythoncapi_compat_extras.h ================================================ #ifndef PYTHONCAPI_COMPAT_EXTRAS #define PYTHONCAPI_COMPAT_EXTRAS #ifdef __cplusplus extern "C" { #endif #include // Python 3.11.0a6 added PyType_GetModuleByDef() to Python.h #if PY_VERSION_HEX < 0x030b00A6 PyObject * PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def) { assert(PyType_Check(type)); if (!PyType_HasFeature(type, Py_TPFLAGS_HEAPTYPE)) { // type_ready_mro() ensures that no heap type is // contained in a static type MRO. goto error; } else { PyHeapTypeObject *ht = (PyHeapTypeObject*)type; PyObject *module = ht->ht_module; if (module && PyModule_GetDef(module) == def) { return module; } } PyObject *res = NULL; PyObject *mro = type->tp_mro; // The type must be ready assert(mro != NULL); assert(PyTuple_Check(mro)); // mro_invoke() ensures that the type MRO cannot be empty. 
assert(PyTuple_GET_SIZE(mro) >= 1); // Also, the first item in the MRO is the type itself, which // we already checked above. We skip it in the loop. assert(PyTuple_GET_ITEM(mro, 0) == (PyObject *)type); Py_ssize_t n = PyTuple_GET_SIZE(mro); for (Py_ssize_t i = 1; i < n; i++) { PyObject *super = PyTuple_GET_ITEM(mro, i); if (!PyType_HasFeature((PyTypeObject *)super, Py_TPFLAGS_HEAPTYPE)) { // Static types in the MRO need to be skipped continue; } PyHeapTypeObject *ht = (PyHeapTypeObject*)super; PyObject *module = ht->ht_module; if (module && PyModule_GetDef(module) == def) { res = module; break; } } if (res != NULL) { return res; } error: PyErr_Format( PyExc_TypeError, "PyType_GetModuleByDef: No superclass of '%s' has the given module", type->tp_name); return NULL; } #endif #ifdef __cplusplus } #endif #endif // PYTHONCAPI_COMPAT_EXTRAS ================================================ FILE: asyncpg/protocol/record/recordobj.c ================================================ /* Big parts of this file are copied (with modifications) from CPython/Objects/tupleobject.c. Portions Copyright (c) PSF (and other CPython copyright holders). Portions Copyright (c) 2016-present MagicStack Inc. License: PSFL v2; see CPython/LICENSE for details. 
*/ #include #include #include "pythoncapi_compat.h" #include "pythoncapi_compat_extras.h" #include "recordobj.h" #ifndef _PyCFunction_CAST #define _PyCFunction_CAST(func) ((PyCFunction)(void (*)(void))(func)) #endif static size_t ApgRecord_MAXSIZE = (((size_t)PY_SSIZE_T_MAX - sizeof(ApgRecordObject) - sizeof(PyObject *)) / sizeof(PyObject *)); /* Largest record to save on free list */ #define ApgRecord_MAXSAVESIZE 20 /* Maximum number of records of each size to save */ #define ApgRecord_MAXFREELIST 2000 typedef struct { ApgRecordObject *freelist[ApgRecord_MAXSAVESIZE]; int numfree[ApgRecord_MAXSAVESIZE]; } record_freelist_state; typedef struct { PyTypeObject *ApgRecord_Type; PyTypeObject *ApgRecordDesc_Type; PyTypeObject *ApgRecordIter_Type; PyTypeObject *ApgRecordItems_Type; Py_tss_t freelist_key; // TSS key for per-thread record_freelist_state } record_module_state; static inline record_module_state * get_module_state(PyObject *module) { void *state = PyModule_GetState(module); if (state == NULL) { PyErr_SetString(PyExc_SystemError, "failed to get record module state"); return NULL; } return (record_module_state *)state; } static inline record_module_state * get_module_state_from_type(PyTypeObject *type) { void *state = PyType_GetModuleState(type); if (state != NULL) { return (record_module_state *)state; } PyErr_Format(PyExc_SystemError, "could not get record module state from '%.100s'", type->tp_name); return NULL; } static struct PyModuleDef _recordmodule; static inline record_module_state * find_module_state_by_def(PyTypeObject *type) { PyObject *mod = PyType_GetModuleByDef(type, &_recordmodule); if (mod == NULL) return NULL; return get_module_state(mod); } static inline record_freelist_state * get_freelist_state(record_module_state *state) { record_freelist_state *freelist; freelist = (record_freelist_state *)PyThread_tss_get(&state->freelist_key); if (freelist == NULL) { freelist = (record_freelist_state *)PyMem_Calloc( 1, sizeof(record_freelist_state)); if 
(freelist == NULL) { PyErr_NoMemory(); return NULL; } if (PyThread_tss_set(&state->freelist_key, (void *)freelist) != 0) { PyMem_Free(freelist); PyErr_SetString( PyExc_SystemError, "failed to set thread-specific data"); return NULL; } } return freelist; } PyObject * make_record(PyTypeObject *type, PyObject *desc, Py_ssize_t size, record_module_state *state) { ApgRecordObject *o; Py_ssize_t i; int need_gc_track = 0; if (size < 0 || desc == NULL || Py_TYPE(desc) != state->ApgRecordDesc_Type) { PyErr_BadInternalCall(); return NULL; } if (type == state->ApgRecord_Type) { record_freelist_state *freelist = NULL; if (size < ApgRecord_MAXSAVESIZE) { freelist = get_freelist_state(state); if (freelist != NULL && freelist->freelist[size] != NULL) { o = freelist->freelist[size]; freelist->freelist[size] = (ApgRecordObject *)o->ob_item[0]; freelist->numfree[size]--; _Py_NewReference((PyObject *)o); } else { freelist = NULL; } } if (freelist == NULL) { if ((size_t)size > ApgRecord_MAXSIZE) { return PyErr_NoMemory(); } o = PyObject_GC_NewVar(ApgRecordObject, state->ApgRecord_Type, size); if (o == NULL) { return NULL; } } need_gc_track = 1; } else { assert(PyType_IsSubtype(type, state->ApgRecord_Type)); if ((size_t)size > ApgRecord_MAXSIZE) { return PyErr_NoMemory(); } o = (ApgRecordObject *)type->tp_alloc(type, size); if (!PyObject_GC_IsTracked((PyObject *)o)) { PyErr_SetString(PyExc_TypeError, "record subclass is not tracked by GC"); return NULL; } } for (i = 0; i < size; i++) { o->ob_item[i] = NULL; } Py_INCREF(desc); o->desc = (ApgRecordDescObject *)desc; o->self_hash = -1; if (need_gc_track) { PyObject_GC_Track(o); } return (PyObject *)o; } static void record_dealloc(PyObject *self) { ApgRecordObject *o = (ApgRecordObject *)self; Py_ssize_t i; Py_ssize_t len = Py_SIZE(o); PyTypeObject *tp = Py_TYPE(o); record_module_state *state; int skip_dealloc = 0; state = find_module_state_by_def(tp); if (state == NULL) { return; } PyObject_GC_UnTrack(o); o->self_hash = -1; 
Py_CLEAR(o->desc); Py_TRASHCAN_BEGIN(o, record_dealloc) i = len; while (--i >= 0) { Py_XDECREF(o->ob_item[i]); } if (len < ApgRecord_MAXSAVESIZE && tp == state->ApgRecord_Type) { record_freelist_state *freelist = get_freelist_state(state); if (freelist != NULL && freelist->numfree[len] < ApgRecord_MAXFREELIST) { o->ob_item[0] = (PyObject *)freelist->freelist[len]; freelist->numfree[len]++; freelist->freelist[len] = o; skip_dealloc = 1; } } if (!skip_dealloc) { tp->tp_free(self); Py_DECREF(tp); } Py_TRASHCAN_END } static int record_traverse(PyObject *self, visitproc visit, void *arg) { ApgRecordObject *o = (ApgRecordObject *)self; for (Py_ssize_t i = Py_SIZE(o); --i >= 0;) { Py_VISIT(o->ob_item[i]); } return 0; } /* Below are the official constants from the xxHash specification. Optimizing compilers should emit a single "rotate" instruction for the _PyTuple_HASH_XXROTATE() expansion. If that doesn't happen for some important platform, the macro could be changed to expand to a platform-specific rotate spelling instead. 
*/ #if SIZEOF_PY_UHASH_T > 4 #define _ApgRecord_HASH_XXPRIME_1 ((Py_uhash_t)11400714785074694791ULL) #define _ApgRecord_HASH_XXPRIME_2 ((Py_uhash_t)14029467366897019727ULL) #define _ApgRecord_HASH_XXPRIME_5 ((Py_uhash_t)2870177450012600261ULL) #define _ApgRecord_HASH_XXROTATE(x) ((x << 31) | (x >> 33)) /* Rotate left 31 bits */ #else #define _ApgRecord_HASH_XXPRIME_1 ((Py_uhash_t)2654435761UL) #define _ApgRecord_HASH_XXPRIME_2 ((Py_uhash_t)2246822519UL) #define _ApgRecord_HASH_XXPRIME_5 ((Py_uhash_t)374761393UL) #define _ApgRecord_HASH_XXROTATE(x) ((x << 13) | (x >> 19)) /* Rotate left 13 bits */ #endif static Py_hash_t record_hash(PyObject *op) { ApgRecordObject *v = (ApgRecordObject *)op; Py_uhash_t acc; Py_ssize_t len = Py_SIZE(v); PyObject **item = v->ob_item; acc = _ApgRecord_HASH_XXPRIME_5; for (Py_ssize_t i = 0; i < len; i++) { Py_uhash_t lane = (Py_uhash_t)PyObject_Hash(item[i]); if (lane == (Py_uhash_t)-1) { return -1; } acc += lane * _ApgRecord_HASH_XXPRIME_2; acc = _ApgRecord_HASH_XXROTATE(acc); acc *= _ApgRecord_HASH_XXPRIME_1; } /* Add input length, mangled to keep the historical value of hash(()). 
*/ acc += (Py_uhash_t)len ^ (_ApgRecord_HASH_XXPRIME_5 ^ 3527539UL); if (acc == (Py_uhash_t)-1) { acc = 1546275796; } return (Py_hash_t)acc; } static Py_ssize_t record_length(PyObject *self) { ApgRecordObject *a = (ApgRecordObject *)self; return Py_SIZE(a); } static int record_contains(PyObject *self, PyObject *el) { ApgRecordObject *a = (ApgRecordObject *)self; if (a->desc == NULL || a->desc->keys == NULL) { return 0; } return PySequence_Contains(a->desc->keys, el); } static PyObject * record_item(ApgRecordObject *op, Py_ssize_t i) { ApgRecordObject *a = (ApgRecordObject *)op; if (i < 0 || i >= Py_SIZE(a)) { PyErr_SetString(PyExc_IndexError, "record index out of range"); return NULL; } return Py_NewRef(a->ob_item[i]); } static PyObject * record_richcompare(PyObject *v, PyObject *w, int op) { Py_ssize_t i; Py_ssize_t vlen, wlen; int v_is_tuple = 0; int w_is_tuple = 0; int v_is_record = 0; int w_is_record = 0; int comp; PyTypeObject *v_type = Py_TYPE(v); PyTypeObject *w_type = Py_TYPE(w); record_module_state *state; state = find_module_state_by_def(v_type); if (state == NULL) { PyErr_Clear(); state = find_module_state_by_def(w_type); } if (PyTuple_Check(v)) { v_is_tuple = 1; } else if (v_type == state->ApgRecord_Type) { v_is_record = 1; } else if (!PyObject_TypeCheck(v, state->ApgRecord_Type)) { Py_RETURN_NOTIMPLEMENTED; } if (PyTuple_Check(w)) { w_is_tuple = 1; } else if (w_type == state->ApgRecord_Type) { w_is_record = 1; } else if (!PyObject_TypeCheck(w, state->ApgRecord_Type)) { Py_RETURN_NOTIMPLEMENTED; } #define V_ITEM(i) \ (v_is_tuple ? PyTuple_GET_ITEM(v, i) \ : (v_is_record ? ApgRecord_GET_ITEM(v, i) : PySequence_GetItem(v, i))) #define W_ITEM(i) \ (w_is_tuple ? PyTuple_GET_ITEM(w, i) \ : (w_is_record ? 
ApgRecord_GET_ITEM(w, i) : PySequence_GetItem(w, i))) vlen = Py_SIZE(v); wlen = Py_SIZE(w); if (op == Py_EQ && vlen != wlen) { /* Checking if v == w, but len(v) != len(w): return False */ Py_RETURN_FALSE; } if (op == Py_NE && vlen != wlen) { /* Checking if v != w, and len(v) != len(w): return True */ Py_RETURN_TRUE; } /* Search for the first index where items are different. * Note that because tuples are immutable, it's safe to reuse * vlen and wlen across the comparison calls. */ for (i = 0; i < vlen && i < wlen; i++) { comp = PyObject_RichCompareBool(V_ITEM(i), W_ITEM(i), Py_EQ); if (comp < 0) { return NULL; } if (!comp) { break; } } if (i >= vlen || i >= wlen) { /* No more items to compare -- compare sizes */ int cmp; switch (op) { case Py_LT: cmp = vlen < wlen; break; case Py_LE: cmp = vlen <= wlen; break; case Py_EQ: cmp = vlen == wlen; break; case Py_NE: cmp = vlen != wlen; break; case Py_GT: cmp = vlen > wlen; break; case Py_GE: cmp = vlen >= wlen; break; default: Py_UNREACHABLE(); } if (cmp) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; } } /* We have an item that differs -- shortcuts for EQ/NE */ if (op == Py_EQ) { Py_RETURN_FALSE; } if (op == Py_NE) { Py_RETURN_TRUE; } /* Compare the final item again using the proper operator */ return PyObject_RichCompare(V_ITEM(i), W_ITEM(i), op); #undef V_ITEM #undef W_ITEM } typedef enum item_by_name_result { APG_ITEM_FOUND = 0, APG_ERROR = -1, APG_ITEM_NOT_FOUND = -2 } item_by_name_result_t; /* Lookup a record value by its name. Return 0 on success, -2 if the * value was not found (with KeyError set), and -1 on all other errors. 
*/ static item_by_name_result_t record_item_by_name(ApgRecordObject *o, PyObject *item, PyObject **result) { PyObject *mapped; PyObject *val; Py_ssize_t i; mapped = PyObject_GetItem(o->desc->mapping, item); if (mapped == NULL) { goto noitem; } if (!PyIndex_Check(mapped)) { Py_DECREF(mapped); goto error; } i = PyNumber_AsSsize_t(mapped, PyExc_IndexError); Py_DECREF(mapped); if (i < 0) { if (PyErr_Occurred()) PyErr_Clear(); goto error; } val = record_item(o, i); if (val == NULL) { PyErr_Clear(); goto error; } *result = val; return APG_ITEM_FOUND; noitem: PyErr_SetObject(PyExc_KeyError, item); return APG_ITEM_NOT_FOUND; error: PyErr_SetString(PyExc_RuntimeError, "invalid record descriptor"); return APG_ERROR; } static PyObject * record_subscript(PyObject *op, PyObject *item) { ApgRecordObject *self = (ApgRecordObject *)op; if (PyIndex_Check(item)) { Py_ssize_t i = PyNumber_AsSsize_t(item, PyExc_IndexError); if (i == -1 && PyErr_Occurred()) return NULL; if (i < 0) { i += Py_SIZE(self); } return record_item(self, i); } else if (PySlice_Check(item)) { Py_ssize_t start, stop, step, cur, slicelength, i; PyObject *it; PyObject **src, **dest; if (PySlice_Unpack(item, &start, &stop, &step) < 0) { return NULL; } slicelength = PySlice_AdjustIndices(Py_SIZE(self), &start, &stop, step); if (slicelength <= 0) { return PyTuple_New(0); } else if (start == 0 && step == 1 && slicelength == Py_SIZE(self) && PyTuple_CheckExact(self)) { return Py_NewRef(self); } else { PyTupleObject *result = (PyTupleObject *)PyTuple_New(slicelength); if (!result) return NULL; src = self->ob_item; dest = result->ob_item; for (cur = start, i = 0; i < slicelength; cur += step, i++) { it = Py_NewRef(src[cur]); dest[i] = it; } return (PyObject *)result; } } else { PyObject *result; if (record_item_by_name(self, item, &result) < 0) return NULL; else return result; } } static const char * get_typename(PyTypeObject *type) { assert(type->tp_name != NULL); const char *s = strrchr(type->tp_name, '.'); if (s == 
NULL) { s = type->tp_name; } else { s++; } return s; } static PyObject * record_repr(PyObject *self) { ApgRecordObject *v = (ApgRecordObject *)self; Py_ssize_t i, n; PyObject *keys_iter; PyUnicodeWriter *writer; n = Py_SIZE(v); if (n == 0) { return PyUnicode_FromFormat("<%s>", get_typename(Py_TYPE(v))); } keys_iter = PyObject_GetIter(v->desc->keys); if (keys_iter == NULL) { return NULL; } i = Py_ReprEnter((PyObject *)v); if (i != 0) { Py_DECREF(keys_iter); if (i > 0) { return PyUnicode_FromFormat("<%s ...>", get_typename(Py_TYPE(v))); } return NULL; } writer = PyUnicodeWriter_Create(12); /* */ if (PyUnicodeWriter_Format(writer, "<%s ", get_typename(Py_TYPE(v))) < 0) { goto error; } for (i = 0; i < n; ++i) { int res; PyObject *key; if (i > 0) if (PyUnicodeWriter_WriteChar(writer, ' ') < 0) goto error; key = PyIter_Next(keys_iter); if (key == NULL) { PyErr_SetString(PyExc_RuntimeError, "invalid record mapping"); goto error; } res = PyUnicodeWriter_WriteStr(writer, key); Py_DECREF(key); if (res < 0) goto error; if (PyUnicodeWriter_WriteChar(writer, '=') < 0) goto error; if (Py_EnterRecursiveCall(" while getting the repr of a record")) goto error; res = PyUnicodeWriter_WriteRepr(writer, v->ob_item[i]); Py_LeaveRecursiveCall(); if (res < 0) goto error; } if (PyUnicodeWriter_WriteChar(writer, '>') < 0) goto error; Py_DECREF(keys_iter); Py_ReprLeave((PyObject *)v); return PyUnicodeWriter_Finish(writer); error: Py_DECREF(keys_iter); PyUnicodeWriter_Discard(writer); Py_ReprLeave((PyObject *)v); return NULL; } static PyObject * record_new_iter(ApgRecordObject *, const record_module_state *); static PyObject * record_iter(PyObject *seq) { ApgRecordObject *r = (ApgRecordObject *)seq; record_module_state *state; state = find_module_state_by_def(Py_TYPE(seq)); if (state == NULL) { return NULL; } return record_new_iter(r, state); } static PyObject * record_values(PyObject *self, PyTypeObject *defcls, PyObject *const *args, size_t nargsf, PyObject *kwnames) { ApgRecordObject *r = 
(ApgRecordObject *)self; record_module_state *state = get_module_state_from_type(defcls); if (state == NULL) return NULL; return record_new_iter(r, state); } static PyObject * record_keys(PyObject *self, PyTypeObject *defcls, PyObject *const *args, size_t nargsf, PyObject *kwnames) { ApgRecordObject *r = (ApgRecordObject *)self; return PyObject_GetIter(r->desc->keys); } static PyObject * record_new_items_iter(ApgRecordObject *, const record_module_state *); static PyObject * record_items(PyObject *self, PyTypeObject *defcls, PyObject *const *args, size_t nargsf, PyObject *kwnames) { ApgRecordObject *r = (ApgRecordObject *)self; record_module_state *state = get_module_state_from_type(defcls); if (state == NULL) return NULL; return record_new_items_iter(r, state); } static PyObject * record_get(PyObject *self, PyTypeObject *defcls, PyObject *const *args, size_t nargsf, PyObject *kwnames) { Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); PyObject *key; PyObject *defval = Py_None; PyObject *val = NULL; int res; if (nargs == 2) { key = args[0]; defval = args[1]; } else if (nargs == 1) { key = args[0]; } else { PyErr_Format(PyExc_TypeError, "Record.get() expected 1 or 2 arguments, got %zd", nargs); } if (kwnames != NULL && PyTuple_GET_SIZE(kwnames) != 0) { PyErr_SetString(PyExc_TypeError, "Record.get() takes no keyword arguments"); return NULL; } res = record_item_by_name((ApgRecordObject *)self, key, &val); if (res == APG_ITEM_NOT_FOUND) { PyErr_Clear(); Py_INCREF(defval); val = defval; } return val; } static PyObject * record_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { record_module_state *state; state = get_module_state_from_type(type); if (state == NULL) { return NULL; } if (type == state->ApgRecord_Type) { PyErr_Format(PyExc_TypeError, "cannot create '%.100s' instances", type->tp_name); return NULL; } /* For subclasses, use the default allocation */ return type->tp_alloc(type, 0); } static PyMethodDef record_methods[] = { {"values", 
_PyCFunction_CAST(record_values), METH_METHOD | METH_FASTCALL | METH_KEYWORDS}, {"keys", _PyCFunction_CAST(record_keys), METH_METHOD | METH_FASTCALL | METH_KEYWORDS}, {"items", _PyCFunction_CAST(record_items), METH_METHOD | METH_FASTCALL | METH_KEYWORDS}, {"get", _PyCFunction_CAST(record_get), METH_METHOD | METH_FASTCALL | METH_KEYWORDS}, {NULL, NULL} /* sentinel */ }; static PyType_Slot ApgRecord_TypeSlots[] = { {Py_tp_dealloc, record_dealloc}, {Py_tp_repr, record_repr}, {Py_tp_hash, record_hash}, {Py_tp_getattro, PyObject_GenericGetAttr}, {Py_tp_traverse, record_traverse}, {Py_tp_richcompare, record_richcompare}, {Py_tp_iter, record_iter}, {Py_tp_methods, record_methods}, {Py_tp_new, record_new}, {Py_tp_free, PyObject_GC_Del}, {Py_sq_length, record_length}, {Py_sq_item, record_item}, {Py_sq_contains, record_contains}, {Py_mp_length, record_length}, {Py_mp_subscript, record_subscript}, {0, NULL}, }; #ifndef Py_TPFLAGS_IMMUTABLETYPE #define Py_TPFLAGS_IMMUTABLETYPE 0 #endif static PyType_Spec ApgRecord_TypeSpec = { .name = "asyncpg.protocol.record.Record", .basicsize = sizeof(ApgRecordObject) - sizeof(PyObject *), .itemsize = sizeof(PyObject *), .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_IMMUTABLETYPE), .slots = ApgRecord_TypeSlots, }; /* Record Iterator */ typedef struct { PyObject_HEAD Py_ssize_t it_index; ApgRecordObject *it_seq; /* Set to NULL when iterator is exhausted */ } ApgRecordIterObject; static void record_iter_dealloc(ApgRecordIterObject *it) { PyTypeObject *tp = Py_TYPE(it); PyObject_GC_UnTrack(it); Py_CLEAR(it->it_seq); PyObject_GC_Del(it); Py_DECREF(tp); } static int record_iter_traverse(ApgRecordIterObject *it, visitproc visit, void *arg) { Py_VISIT(it->it_seq); return 0; } static PyObject * record_iter_next(ApgRecordIterObject *it) { ApgRecordObject *seq; PyObject *item; assert(it != NULL); seq = it->it_seq; if (seq == NULL) return NULL; if (it->it_index < Py_SIZE(seq)) { item = ApgRecord_GET_ITEM(seq, 
it->it_index); ++it->it_index; Py_INCREF(item); return item; } it->it_seq = NULL; Py_DECREF(seq); return NULL; } static PyObject * record_iter_len(ApgRecordIterObject *it) { Py_ssize_t len = 0; if (it->it_seq) { len = Py_SIZE(it->it_seq) - it->it_index; } return PyLong_FromSsize_t(len); } PyDoc_STRVAR(record_iter_len_doc, "Private method returning an estimate of len(list(it))."); static PyMethodDef record_iter_methods[] = { {"__length_hint__", (PyCFunction)record_iter_len, METH_NOARGS, record_iter_len_doc}, {NULL, NULL} /* sentinel */ }; static PyType_Slot ApgRecordIter_TypeSlots[] = { {Py_tp_dealloc, (destructor)record_iter_dealloc}, {Py_tp_getattro, PyObject_GenericGetAttr}, {Py_tp_traverse, (traverseproc)record_iter_traverse}, {Py_tp_iter, PyObject_SelfIter}, {Py_tp_iternext, (iternextfunc)record_iter_next}, {Py_tp_methods, record_iter_methods}, {0, NULL}, }; static PyType_Spec ApgRecordIter_TypeSpec = { .name = "asyncpg.protocol.record.RecordIterator", .basicsize = sizeof(ApgRecordIterObject), .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, .slots = ApgRecordIter_TypeSlots, }; static PyObject * record_new_iter(ApgRecordObject *r, const record_module_state *state) { ApgRecordIterObject *it; it = PyObject_GC_New(ApgRecordIterObject, state->ApgRecordIter_Type); if (it == NULL) return NULL; it->it_index = 0; Py_INCREF(r); it->it_seq = r; PyObject_GC_Track(it); return (PyObject *)it; } /* Record Items Iterator */ typedef struct { PyObject_HEAD Py_ssize_t it_index; PyObject *it_key_iter; ApgRecordObject *it_seq; /* Set to NULL when iterator is exhausted */ } ApgRecordItemsObject; static void record_items_dealloc(ApgRecordItemsObject *it) { PyTypeObject *tp = Py_TYPE(it); PyObject_GC_UnTrack(it); Py_CLEAR(it->it_key_iter); Py_CLEAR(it->it_seq); PyObject_GC_Del(it); Py_DECREF(tp); } static int record_items_traverse(ApgRecordItemsObject *it, visitproc visit, void *arg) { Py_VISIT(it->it_key_iter); Py_VISIT(it->it_seq); return 0; } static PyObject * 
record_items_next(ApgRecordItemsObject *it) { ApgRecordObject *seq; PyObject *key; PyObject *val; PyObject *tup; assert(it != NULL); seq = it->it_seq; if (seq == NULL) { return NULL; } assert(it->it_key_iter != NULL); key = PyIter_Next(it->it_key_iter); if (key == NULL) { /* likely it_key_iter had less items than seq has values */ goto exhausted; } if (it->it_index < Py_SIZE(seq)) { val = ApgRecord_GET_ITEM(seq, it->it_index); ++it->it_index; Py_INCREF(val); } else { /* it_key_iter had more items than seq has values */ Py_DECREF(key); goto exhausted; } tup = PyTuple_New(2); if (tup == NULL) { Py_DECREF(val); Py_DECREF(key); goto exhausted; } PyTuple_SET_ITEM(tup, 0, key); PyTuple_SET_ITEM(tup, 1, val); return tup; exhausted: Py_CLEAR(it->it_key_iter); Py_CLEAR(it->it_seq); return NULL; } static PyObject * record_items_len(ApgRecordItemsObject *it) { Py_ssize_t len = 0; if (it->it_seq) { len = Py_SIZE(it->it_seq) - it->it_index; } return PyLong_FromSsize_t(len); } PyDoc_STRVAR(record_items_len_doc, "Private method returning an estimate of len(list(it()))."); static PyMethodDef record_items_methods[] = { {"__length_hint__", (PyCFunction)record_items_len, METH_NOARGS, record_items_len_doc}, {NULL, NULL} /* sentinel */ }; static PyType_Slot ApgRecordItems_TypeSlots[] = { {Py_tp_dealloc, (destructor)record_items_dealloc}, {Py_tp_getattro, PyObject_GenericGetAttr}, {Py_tp_traverse, (traverseproc)record_items_traverse}, {Py_tp_iter, PyObject_SelfIter}, {Py_tp_iternext, (iternextfunc)record_items_next}, {Py_tp_methods, record_items_methods}, {0, NULL}, }; static PyType_Spec ApgRecordItems_TypeSpec = { .name = "asyncpg.protocol.record.RecordItemsIterator", .basicsize = sizeof(ApgRecordItemsObject), .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, .slots = ApgRecordItems_TypeSlots, }; static PyObject * record_new_items_iter(ApgRecordObject *r, const record_module_state *state) { ApgRecordItemsObject *it; PyObject *key_iter; key_iter = PyObject_GetIter(r->desc->keys); if 
(key_iter == NULL) return NULL; it = PyObject_GC_New(ApgRecordItemsObject, state->ApgRecordItems_Type); if (it == NULL) { Py_DECREF(key_iter); return NULL; } it->it_key_iter = key_iter; it->it_index = 0; Py_INCREF(r); it->it_seq = r; PyObject_GC_Track(it); return (PyObject *)it; } /* ----------------- */ static void record_desc_dealloc(ApgRecordDescObject *o) { PyTypeObject *tp = Py_TYPE(o); PyObject_GC_UnTrack(o); Py_CLEAR(o->mapping); Py_CLEAR(o->keys); PyObject_GC_Del(o); Py_DECREF(tp); } static int record_desc_traverse(ApgRecordDescObject *o, visitproc visit, void *arg) { Py_VISIT(o->mapping); Py_VISIT(o->keys); return 0; } static PyObject * record_desc_vectorcall(PyObject *type, PyObject *const *args, size_t nargsf, PyObject *kwnames) { PyObject *mapping; PyObject *keys; ApgRecordDescObject *o; Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); if (kwnames != NULL && PyTuple_GET_SIZE(kwnames) != 0) { PyErr_SetString(PyExc_TypeError, "RecordDescriptor() takes no keyword arguments"); return NULL; } if (nargs != 2) { PyErr_Format(PyExc_TypeError, "RecordDescriptor() takes exactly 2 arguments (%zd given)", nargs); return NULL; } mapping = args[0]; keys = args[1]; if (!PyTuple_CheckExact(keys)) { PyErr_SetString(PyExc_TypeError, "keys must be a tuple"); return NULL; } o = PyObject_GC_New(ApgRecordDescObject, (PyTypeObject *)type); if (o == NULL) { return NULL; } Py_INCREF(mapping); o->mapping = mapping; Py_INCREF(keys); o->keys = keys; PyObject_GC_Track(o); return (PyObject *)o; } /* Fallback wrapper for when there is no vectorcall support */ static PyObject * record_desc_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { PyObject *const *args_array; size_t nargsf; PyObject *kwnames = NULL; if (kwargs != NULL && PyDict_GET_SIZE(kwargs) != 0) { PyErr_SetString(PyExc_TypeError, "RecordDescriptor() takes no keyword arguments"); return NULL; } if (!PyTuple_Check(args)) { PyErr_SetString(PyExc_TypeError, "args must be a tuple"); return NULL; } nargsf = 
(size_t)PyTuple_GET_SIZE(args); args_array = &PyTuple_GET_ITEM(args, 0); return record_desc_vectorcall((PyObject *)type, args_array, nargsf, kwnames); } static PyObject * record_desc_make_record(PyObject *desc, PyTypeObject *desc_type, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) { PyObject *type_obj; Py_ssize_t size; record_module_state *state = get_module_state_from_type(desc_type); if (state == NULL) { return NULL; } if (nargs != 2) { PyErr_Format(PyExc_TypeError, "RecordDescriptor.make_record() takes exactly 2 arguments (%zd given)", nargs); return NULL; } if (kwnames != NULL && PyTuple_GET_SIZE(kwnames) != 0) { PyErr_SetString(PyExc_TypeError, "RecordDescriptor.make_record() takes no keyword arguments"); return NULL; } type_obj = args[0]; size = PyLong_AsSsize_t(args[1]); if (size == -1 && PyErr_Occurred()) { return NULL; } if (!PyType_Check(type_obj)) { PyErr_SetString(PyExc_TypeError, "RecordDescriptor.make_record(): first argument must be a type"); return NULL; } return make_record((PyTypeObject *)type_obj, desc, size, state); } static PyMethodDef record_desc_methods[] = { {"make_record", _PyCFunction_CAST(record_desc_make_record), METH_FASTCALL | METH_METHOD | METH_KEYWORDS}, {NULL, NULL} /* sentinel */ }; static PyType_Slot ApgRecordDesc_TypeSlots[] = { #ifdef Py_tp_vectorcall {Py_tp_vectorcall, (vectorcallfunc)record_desc_vectorcall}, #endif {Py_tp_new, (newfunc)record_desc_new}, {Py_tp_dealloc, (destructor)record_desc_dealloc}, {Py_tp_getattro, PyObject_GenericGetAttr}, {Py_tp_traverse, (traverseproc)record_desc_traverse}, {Py_tp_methods, record_desc_methods}, {0, NULL}, }; static PyType_Spec ApgRecordDesc_TypeSpec = { .name = "asyncpg.protocol.record.RecordDescriptor", .basicsize = sizeof(ApgRecordDescObject), .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_IMMUTABLETYPE, .slots = ApgRecordDesc_TypeSlots, }; /* * Module init */ static PyMethodDef record_module_methods[] = {{NULL, NULL, 0, NULL}}; static int 
record_module_exec(PyObject *module) { record_module_state *state = get_module_state(module); if (state == NULL) { return -1; } if (PyThread_tss_create(&state->freelist_key) != 0) { PyErr_SetString( PyExc_SystemError, "failed to create TSS key for record freelist"); return -1; } #define CREATE_TYPE(m, tp, spec) \ do { \ tp = (PyTypeObject *)PyType_FromModuleAndSpec(m, spec, NULL); \ if (tp == NULL) \ goto error; \ if (PyModule_AddType(m, tp) < 0) \ goto error; \ } while (0) CREATE_TYPE(module, state->ApgRecord_Type, &ApgRecord_TypeSpec); CREATE_TYPE(module, state->ApgRecordDesc_Type, &ApgRecordDesc_TypeSpec); CREATE_TYPE(module, state->ApgRecordIter_Type, &ApgRecordIter_TypeSpec); CREATE_TYPE(module, state->ApgRecordItems_Type, &ApgRecordItems_TypeSpec); #undef CREATE_TYPE return 0; error: Py_CLEAR(state->ApgRecord_Type); Py_CLEAR(state->ApgRecordDesc_Type); Py_CLEAR(state->ApgRecordIter_Type); Py_CLEAR(state->ApgRecordItems_Type); return -1; } static int record_module_traverse(PyObject *module, visitproc visit, void *arg) { record_module_state *state = get_module_state(module); if (state == NULL) { return 0; } Py_VISIT(state->ApgRecord_Type); Py_VISIT(state->ApgRecordDesc_Type); Py_VISIT(state->ApgRecordIter_Type); Py_VISIT(state->ApgRecordItems_Type); return 0; } static int record_module_clear(PyObject *module) { record_module_state *state = get_module_state(module); if (state == NULL) { return 0; } if (PyThread_tss_is_created(&state->freelist_key)) { record_freelist_state *freelist = (record_freelist_state *)PyThread_tss_get(&state->freelist_key); if (freelist != NULL) { for (int i = 0; i < ApgRecord_MAXSAVESIZE; i++) { ApgRecordObject *op = freelist->freelist[i]; while (op != NULL) { ApgRecordObject *next = (ApgRecordObject *)(op->ob_item[0]); PyObject_GC_Del(op); op = next; } freelist->freelist[i] = NULL; freelist->numfree[i] = 0; } PyMem_Free(freelist); PyThread_tss_set(&state->freelist_key, NULL); } PyThread_tss_delete(&state->freelist_key); } 
Py_CLEAR(state->ApgRecord_Type); Py_CLEAR(state->ApgRecordDesc_Type); Py_CLEAR(state->ApgRecordIter_Type); Py_CLEAR(state->ApgRecordItems_Type); return 0; } static void record_module_free(void *module) { record_module_clear((PyObject *)module); } static PyModuleDef_Slot record_module_slots[] = { {Py_mod_exec, record_module_exec}, #ifdef Py_mod_multiple_interpreters {Py_mod_multiple_interpreters, Py_MOD_PER_INTERPRETER_GIL_SUPPORTED}, #endif #ifdef Py_mod_gil {Py_mod_gil, Py_MOD_GIL_NOT_USED}, #endif {0, NULL}, }; static struct PyModuleDef _recordmodule = { PyModuleDef_HEAD_INIT, .m_name = "asyncpg.protocol.record", .m_size = sizeof(record_module_state), .m_methods = record_module_methods, .m_slots = record_module_slots, .m_traverse = record_module_traverse, .m_clear = record_module_clear, .m_free = record_module_free, }; PyMODINIT_FUNC PyInit_record(void) { return PyModuleDef_Init(&_recordmodule); } ================================================ FILE: asyncpg/protocol/record/recordobj.h ================================================ #ifndef APG_RECORDOBJ_H #define APG_RECORDOBJ_H #include typedef struct { PyObject_HEAD PyObject *mapping; PyObject *keys; } ApgRecordDescObject; typedef struct { PyObject_VAR_HEAD Py_hash_t self_hash; ApgRecordDescObject *desc; PyObject *ob_item[1]; /* ob_item contains space for 'ob_size' elements. * Items must normally not be NULL, except during construction when * the record is not yet visible outside the function that builds it. */ } ApgRecordObject; #define ApgRecord_SET_ITEM(op, i, v) \ (((ApgRecordObject *)(op))->ob_item[i] = v) #define ApgRecord_GET_ITEM(op, i) \ (((ApgRecordObject *)(op))->ob_item[i]) #endif ================================================ FILE: asyncpg/protocol/record.pyi ================================================ from typing import ( Any, TypeVar, overload, ) from collections.abc import Iterator _T = TypeVar("_T") class Record: @overload def get(self, key: str) -> Any | None: ... 
# C-level declarations exposed to Cython code from the record extension's
# header.
cdef extern from "record/recordobj.h":

    # Macro from recordobj.h: a plain store of the value into
    # ``ob_item[i]`` of the record.  The macro itself performs no
    # reference-count adjustment and no bounds check.
    void ApgRecord_SET_ITEM(object, int, object)

    # NOTE(review): no declaration of RecordDescriptor is visible in the
    # recordobj.h excerpt next to this file -- confirm where it is provided.
    object RecordDescriptor(object, object)
@cython.final
cdef class SCRAMAuthentication:
    """Client side of the SCRAM-SHA-256 authentication exchange.

    Since PostgreSQL 10, the option to hash passwords using the SCRAM-SHA-256
    method was added. This module follows the defined protocol, which can be
    referenced from here:

    https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256

    libpq references the following RFCs that it uses for implementation:

        * RFC 5802
        * RFC 5803
        * RFC 7677

    The protocol works as such:

    - A client connects to the server. The server requests the client to begin
    SASL authentication using SCRAM and presents a client with the methods it
    supports. At present, those are SCRAM-SHA-256, and, on servers that are
    built with OpenSSL and are PG11+, SCRAM-SHA-256-PLUS (which supports
    channel binding, more on that below)

    - The client sends a "first message" to the server, where it chooses which
    method to authenticate with, and sends, along with the method, an
    indication of channel binding (we disable for now), a nonce, and the
    username. (Technically, PostgreSQL ignores the username as it already has
    it from the initial connection, but we add it for completeness)

    - The server responds with a "first message" in which it extends the
    nonce, as well as a password salt and the number of iterations to hash the
    password with. The client validates that the new nonce contains the first
    part of the client's original nonce

    - The client generates a salted password, but does not send this up to the
    server. Instead, the client follows the SCRAM algorithm (RFC5802) to
    generate a proof. This proof is sent as part of a client "final message"
    to the server for it to validate.

    - The server validates the proof. If it is valid, the server sends a
    verification code for the client to verify that the server came to the
    same proof the client did. PostgreSQL immediately sends an
    AuthenticationOK response right after a valid negotiation. If the password
    the client provided was invalid, then authentication fails.

    (The beauty of this is that the salted password is never transmitted over
    the wire!)

    PostgreSQL 11 added support for the channel binding (i.e.
    SCRAM-SHA-256-PLUS) but due to some ongoing discussion, there is a
    conscious decision by several driver authors to not support it as of yet.
    As such, the channel binding parameter is hard-coded to "n" for now, but
    can be updated to support other channel binding methods in the future
    """
    AUTHENTICATION_METHODS = [b"SCRAM-SHA-256"]
    DEFAULT_CLIENT_NONCE_BYTES = 24
    DIGEST = hashlib.sha256
    REQUIREMENTS_CLIENT_FINAL_MESSAGE = ['client_channel_binding',
                                         'server_nonce']
    REQUIREMENTS_CLIENT_PROOF = ['password_iterations', 'password_salt',
                                 'server_first_message', 'server_nonce']
    SASLPREP_PROHIBITED = (
        stringprep.in_table_a1,  # PostgreSQL treats this as prohibited
        stringprep.in_table_c12,
        stringprep.in_table_c21_c22,
        stringprep.in_table_c3,
        stringprep.in_table_c4,
        stringprep.in_table_c5,
        stringprep.in_table_c6,
        stringprep.in_table_c7,
        stringprep.in_table_c8,
        stringprep.in_table_c9,
    )

    def __cinit__(self, bytes authentication_method):
        self.authentication_method = authentication_method
        self.authorization_message = None
        # channel binding is turned off for the time being
        self.client_channel_binding = b"n,,"
        self.client_first_message_bare = None
        self.client_nonce = None
        self.client_proof = None
        self.password_salt = None
        # password_iterations is declared as a C ``int`` in scram.pxd and
        # therefore cannot be set to None; it is assigned by
        # parse_server_first_message().
        self.server_first_message = None
        self.server_key = None
        self.server_nonce = None

    cdef create_client_first_message(self, str username):
        """Create the initial client message for SCRAM authentication.

        Generates a fresh client nonce, remembers the "bare" first message
        (needed later for the authorization message) and returns the full
        payload: the method name, a NUL byte, a 4-byte big-endian length and
        the client first message itself.
        """
        cdef:
            bytes msg
            bytes client_first_message

        self.client_nonce = \
            self._generate_client_nonce(self.DEFAULT_CLIENT_NONCE_BYTES)
        # set the client first message bare here, as it's used in a later step
        self.client_first_message_bare = b"n=" + username.encode("utf-8") + \
            b",r=" + self.client_nonce
        # put together the full message here
        msg = bytes()
        msg += self.authentication_method + b"\0"
        client_first_message = self.client_channel_binding + \
            self.client_first_message_bare
        msg += (len(client_first_message)).to_bytes(4, byteorder='big') + \
            client_first_message
        return msg

    cdef create_client_final_message(self, str password):
        """Create the final client message as part of SCRAM authentication.

        Raises ``Exception`` if the values that must come from the server's
        first message have not been set yet.
        """
        cdef:
            bytes msg

        if any([getattr(self, val) is None for val in
                self.REQUIREMENTS_CLIENT_FINAL_MESSAGE]):
            raise Exception(
                "you need values from server to generate a client proof")

        # normalize the password using the SASLprep algorithm in RFC 4013
        password = self._normalize_password(password)

        # generate the client proof
        self.client_proof = self._generate_client_proof(password=password)
        msg = bytes()
        msg += b"c=" + base64.b64encode(self.client_channel_binding) + \
            b",r=" + self.server_nonce + \
            b",p=" + base64.b64encode(self.client_proof)
        return msg

    cdef parse_server_first_message(self, bytes server_response):
        """Parse the response from the first message from the server.

        Extracts the combined nonce, the base64 salt and the iteration count
        and stores them on the instance.  Raises ``Exception`` when any of
        them is missing or when the nonce does not start with the client
        nonce.
        """
        self.server_first_message = server_response

        # re.search() returns None when nothing matches, so the result must
        # be checked explicitly; calling .group() on it would raise
        # AttributeError rather than a descriptive error.
        found = re.search(b'r=([^,]+),', self.server_first_message)
        if found is None:
            raise Exception("could not get nonce")
        self.server_nonce = found.group(1)
        if not self.server_nonce.startswith(self.client_nonce):
            raise Exception("invalid nonce")

        found = re.search(b',s=([^,]+),', self.server_first_message)
        if found is None:
            raise Exception("could not get salt")
        self.password_salt = found.group(1)

        found = re.search(b',i=(\d+),?', self.server_first_message)
        if found is None:
            raise Exception("could not get iterations")
        try:
            self.password_iterations = int(found.group(1))
        except (TypeError, ValueError):
            raise Exception("could not get iterations")

    cdef verify_server_final_message(self, bytes server_final_message):
        """Verify the final message from the server.

        Returns True when the server signature matches the one computed
        locally from the server key and the authorization message, False
        otherwise.  Raises ``Exception`` if the message carries no signature.
        """
        cdef:
            bytes server_signature

        found = re.search(b'v=([^,]+)', server_final_message)
        if found is None:
            raise Exception("could not get server signature")
        server_signature = found.group(1)

        verify_server_signature = hmac.new(self.server_key.digest(),
                                           self.authorization_message,
                                           self.DIGEST)
        # validate the server signature against the verifier; use a
        # constant-time comparison as this is an authentication value
        return hmac.compare_digest(
            server_signature,
            base64.b64encode(verify_server_signature.digest()))

    cdef _bytes_xor(self, bytes a, bytes b):
        """XOR two bytestrings together"""
        return bytes(a_i ^ b_i for a_i, b_i in zip(a, b))

    cdef _generate_client_nonce(self, int num_bytes):
        """Return ``num_bytes`` random bytes, base64-encoded."""
        cdef:
            bytes token

        token = secrets.token_bytes(num_bytes)

        return base64.b64encode(token)

    cdef _generate_client_proof(self, str password):
        """Compute the client proof for the (normalized) password.

        Requires that the server's first message has been parsed; raises
        ``Exception`` otherwise.  As a side effect stores ``server_key`` and
        ``authorization_message``, which verify_server_final_message() uses.
        """
        cdef:
            bytes salted_password

        if any([getattr(self, val) is None for val in
                self.REQUIREMENTS_CLIENT_PROOF]):
            raise Exception(
                "you need values from server to generate a client proof")
        # generate a salt password
        salted_password = self._generate_salted_password(
            password, self.password_salt, self.password_iterations)
        # client key is derived from the salted password
        client_key = hmac.new(salted_password, b"Client Key", self.DIGEST)
        # this allows us to compute the stored key that is residing on the
        # server
        stored_key = self.DIGEST(client_key.digest())
        # as well as compute the server key
        self.server_key = hmac.new(salted_password, b"Server Key",
                                   self.DIGEST)
        # build the authorization message that will be used in the
        # client signature
        # the "c=" portion is for the channel binding, but this is not
        # presently implemented
        self.authorization_message = self.client_first_message_bare + b"," + \
            self.server_first_message + b",c=" + \
            base64.b64encode(self.client_channel_binding) + \
            b",r=" + self.server_nonce
        # sign!
        client_signature = hmac.new(stored_key.digest(),
                                    self.authorization_message, self.DIGEST)
        # and the proof
        return self._bytes_xor(client_key.digest(), client_signature.digest())

    cdef _generate_salted_password(self, str password, bytes salt,
                                   int iterations):
        """This follows the "Hi" algorithm specified in RFC5802"""
        cdef:
            bytes p
            bytes s
            bytes u

        # convert the password to a binary string - UTF8 is safe for SASL
        # (though there are SASLPrep rules)
        p = password.encode("utf8")
        # the salt needs to be base64 decoded -- full binary must be used
        s = base64.b64decode(salt)
        # the initial signature is the salt with a terminator of a 32-bit
        # string ending in 1
        ui = hmac.new(p, s + b'\x00\x00\x00\x01', self.DIGEST)
        # grab the initial digest
        u = ui.digest()
        # for X number of iterations, recompute the HMAC signature against the
        # password and the latest iteration of the hash, and XOR it with the
        # previous version
        for _ in range(iterations - 1):
            # use the class digest here as well, so that every step of the
            # derivation agrees on the hash function
            ui = hmac.new(p, ui.digest(), self.DIGEST)
            # this is a fancy way of XORing two byte strings together
            u = self._bytes_xor(u, ui.digest())
        return u

    cdef _normalize_password(self, str original_password):
        """Normalize the password using the SASLprep from RFC4013.

        Returns the normalized password, or ``original_password`` unchanged
        whenever one of the SASLprep steps rejects or empties it.
        """
        cdef:
            str normalized_password

        # Note: Per the PostgreSQL documentation, PostgreSQL does not require
        # UTF-8 to be used for the password, but will perform SASLprep on the
        # password regardless.
        # If the password is not valid UTF-8, PostgreSQL will then **not** use
        # SASLprep processing.
        # If the password fails SASLprep, the password should still be sent
        # See: https://www.postgresql.org/docs/current/sasl-authentication.html
        # and
        # https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/common/saslprep.c
        # using the `pg_saslprep` function
        normalized_password = original_password
        # if the original password is an ASCII string or fails to encode as a
        # UTF-8 string, then no further action is needed
        try:
            original_password.encode("ascii")
        except UnicodeEncodeError:
            pass
        else:
            return original_password

        # Step 1 of SASLPrep: Map. Per the algorithm, we map non-ascii space
        # characters to ASCII spaces (\x20 or \u0020, but we will use ' ') and
        # commonly mapped to nothing characters are removed
        # Table C.1.2 -- non-ASCII spaces
        # Table B.1 -- "Commonly mapped to nothing"
        # NOTE(review): tuple() is kept from the original; presumably it makes
        # each item a Python str rather than a C character under Cython --
        # confirm before removing.
        normalized_password = u"".join(
            ' ' if stringprep.in_table_c12(c) else c
            for c in tuple(normalized_password)
            if not stringprep.in_table_b1(c)
        )

        # If at this point the password is empty, PostgreSQL uses the original
        # password
        if not normalized_password:
            return original_password

        # Step 2 of SASLPrep: Normalize. Normalize the password using the
        # Unicode normalization algorithm to NFKC form
        normalized_password = unicodedata.normalize(
            'NFKC', normalized_password)

        # If normalization left the password empty, PostgreSQL uses the
        # original password
        if not normalized_password:
            return original_password

        normalized_password_tuple = tuple(normalized_password)

        # Step 3 of SASLPrep: Prohibited characters. If PostgreSQL detects any
        # of the prohibited characters in SASLPrep, it will use the original
        # password
        # We also include "unassigned code points" in the prohibited character
        # category as PostgreSQL does the same
        for c in normalized_password_tuple:
            if any(
                in_prohibited_table(c)
                for in_prohibited_table in self.SASLPREP_PROHIBITED
            ):
                return original_password

        # Step 4 of SASLPrep: Bi-directional characters. PostgreSQL follows the
        # rules for bi-directional characters laid out in RFC3454 Sec. 6 which
        # are:
        # 1. Characters in RFC 3454 Sec 5.8 are prohibited (C.8)
        # 2. If a string contains a RandALCat character, it cannot contain any
        #    LCat character
        # 3. If the string contains any RandALCat character, a RandALCat
        #    character must be the first and last character of the string
        # RandALCat characters are found in table D.1, whereas LCat are in D.2
        if any(stringprep.in_table_d1(c) for c in normalized_password_tuple):
            # if the first character or the last character are not in D.1,
            # return the original password
            if not (stringprep.in_table_d1(normalized_password_tuple[0]) and
                    stringprep.in_table_d1(normalized_password_tuple[-1])):
                return original_password

            # if any characters are in D.2, use the original password
            if any(
                stringprep.in_table_d2(c) for c in normalized_password_tuple
            ):
                return original_password

        # return the normalized password
        return normalized_password
asyncpg/protocol/settings.pyx ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 from asyncpg import exceptions @cython.final cdef class ConnectionSettings(pgproto.CodecContext): def __cinit__(self): self._encoding = 'utf-8' self._is_utf8 = True self._settings = {} self._codec = codecs.lookup('utf-8') self._data_codecs = DataCodecConfig() cdef add_setting(self, str name, str val): self._settings[name] = val if name == 'client_encoding': py_enc = get_python_encoding(val) self._codec = codecs.lookup(py_enc) self._encoding = self._codec.name self._is_utf8 = self._encoding == 'utf-8' cdef is_encoding_utf8(self): return self._is_utf8 cpdef get_text_codec(self): return self._codec cpdef inline register_data_types(self, types): self._data_codecs.add_types(types) cpdef inline add_python_codec(self, typeoid, typename, typeschema, typeinfos, typekind, encoder, decoder, format): cdef: ServerDataFormat _format ClientExchangeFormat xformat if format == 'binary': _format = PG_FORMAT_BINARY xformat = PG_XFORMAT_OBJECT elif format == 'text': _format = PG_FORMAT_TEXT xformat = PG_XFORMAT_OBJECT elif format == 'tuple': _format = PG_FORMAT_ANY xformat = PG_XFORMAT_TUPLE else: raise exceptions.InterfaceError( 'invalid `format` argument, expected {}, got {!r}'.format( "'text', 'binary' or 'tuple'", format )) self._data_codecs.add_python_codec(typeoid, typename, typeschema, typekind, typeinfos, encoder, decoder, _format, xformat) cpdef inline remove_python_codec(self, typeoid, typename, typeschema): self._data_codecs.remove_python_codec(typeoid, typename, typeschema) cpdef inline clear_type_cache(self): self._data_codecs.clear_type_cache() cpdef inline set_builtin_type_codec(self, typeoid, typename, typeschema, typekind, alias_to, format): cdef: ServerDataFormat _format if format is None: _format = 
# Each component is captured under the name that _VersionDict and
# split_server_version_string() look it up by.  Without the group names
# "(?P[0-9]+)" is not a valid pattern and re.compile() raises re.error.
version_regex: typing.Final = re.compile(
    r"(Postgre[^\s]*)?\s*"
    r"(?P<major>[0-9]+)\.?"
    r"((?P<minor>[0-9]+)\.?)?"
    r"(?P<micro>[0-9]+)?"
    r"(?P<releaselevel>[a-z]+)?"
    r"(?P<serial>[0-9]+)?"
)


class _VersionDict(typing.TypedDict):
    # Shape of ``version_regex``'s groupdict() once the numeric parts have
    # been converted to int; groups that did not participate stay None.
    major: int
    minor: int | None
    micro: int | None
    releaselevel: str | None
    serial: int | None


def split_server_version_string(version_string: str) -> ServerVersion:
    """Parse a server version string into a :class:`ServerVersion`.

    :param version_string: text containing a version number, optionally
        preceded by a word starting with ``Postgre``.
    :return: the parsed version; missing parts default to ``0`` and the
        release level defaults to ``"final"``.
    :raises ValueError: if no version number can be found in the string.
    """
    version_match = version_regex.search(version_string)

    if version_match is None:
        raise ValueError(
            "Unable to parse Postgres "
            f'version from "{version_string}"'
        )

    version: _VersionDict = version_match.groupdict()  # type: ignore[assignment]  # noqa: E501
    for ver_key, ver_value in version.items():
        # Cast all possible versions parts to int
        try:
            version[ver_key] = int(ver_value)  # type: ignore[literal-required, call-overload]  # noqa: E501
        except (TypeError, ValueError):
            # None (group absent) or a non-numeric part such as "beta"
            pass

    if version["major"] < 10:
        return ServerVersion(
            version["major"],
            version.get("minor") or 0,
            version.get("micro") or 0,
            version.get("releaselevel") or "final",
            version.get("serial") or 0,
        )

    # Since PostgreSQL 10 the versioning scheme has changed.
    # 10.x really means 10.0.x.  While parsing 10.1
    # as (10, 1) may seem less confusing, in practice most
    # version checks are written as version[:2], and we
    # want to keep that behaviour consistent, i.e not fail
    # a major version check due to a bugfix release.
    return ServerVersion(
        version["major"],
        0,
        version.get("minor") or 0,
        version.get("releaselevel") or "final",
        version.get("serial") or 0,
    )
class Transaction(connresource.ConnectionResource):
    """Represents a transaction or savepoint block.

    Transactions are created by calling the
    :meth:`Connection.transaction() <connection.Connection.transaction>`
    function.

    The object can be driven either as an ``async with`` context manager or
    manually through :meth:`start`, :meth:`commit` and :meth:`rollback`;
    the manual ``commit``/``rollback`` calls are rejected while the object
    is being used as a context manager.
    """

    __slots__ = ('_connection', '_isolation', '_readonly', '_deferrable',
                 '_state', '_nested', '_id', '_managed')

    def __init__(self, connection, isolation, readonly, deferrable):
        """Create a new, not yet started transaction.

        Raises ``ValueError`` if *isolation* is truthy and not one of
        ``ISOLATION_LEVELS``.
        """
        super().__init__(connection)

        if isolation and isolation not in ISOLATION_LEVELS:
            raise ValueError(
                'isolation is expected to be either of {}, '
                'got {!r}'.format(ISOLATION_LEVELS, isolation))

        self._isolation = isolation
        self._readonly = readonly
        self._deferrable = deferrable
        self._state = TransactionState.NEW
        # True once start() finds an enclosing transaction on the connection
        self._nested = False
        # savepoint name; only assigned for nested transactions
        self._id = None
        # True while inside an ``async with`` block
        self._managed = False

    async def __aenter__(self):
        """Mark the transaction as context-managed and start it."""
        if self._managed:
            raise apg_errors.InterfaceError(
                'cannot enter context: already in an `async with` block')
        self._managed = True
        await self.start()

    async def __aexit__(self, extype, ex, tb):
        """Roll back if the block raised, commit otherwise."""
        try:
            self._check_conn_validity('__aexit__')
        except apg_errors.InterfaceError:
            if extype is GeneratorExit:
                # When a PoolAcquireContext is being exited, and there
                # is an open transaction in an async generator that has
                # not been iterated fully, there is a possibility that
                # Pool.release() would race with this __aexit__(), since
                # both would be in concurrent tasks.  In such case we
                # yield to Pool.release() to do the ROLLBACK for us.
                # See https://github.com/MagicStack/asyncpg/issues/232
                # for an example.
                return
            else:
                raise

        try:
            if extype is not None:
                await self.__rollback()
            else:
                await self.__commit()
        finally:
            # Reset even when the commit/rollback query itself failed.
            self._managed = False

    @connresource.guarded
    async def start(self):
        """Enter the transaction or savepoint block.

        Issues ``BEGIN`` (with the configured options) when the connection
        has no transaction object yet, or ``SAVEPOINT`` when one is already
        active.  Raises ``InterfaceError`` if the transaction has already
        been started or finished, if the connection is inside a transaction
        that was not opened through a ``Transaction`` object, or if a nested
        transaction requests an isolation level different from the outer one.
        """
        self.__check_state_base('start')
        if self._state is TransactionState.STARTED:
            raise apg_errors.InterfaceError(
                'cannot start; the transaction is already started')

        con = self._connection

        if con._top_xact is None:
            if con._protocol.is_in_transaction():
                raise apg_errors.InterfaceError(
                    'cannot use Connection.transaction() in '
                    'a manually started transaction')
            con._top_xact = self
        else:
            # Nested transaction block
            if self._isolation:
                top_xact_isolation = con._top_xact._isolation
                if top_xact_isolation is None:
                    # The outer transaction did not request a level
                    # explicitly; ask the server and map its answer to the
                    # spelling used by ISOLATION_LEVELS.
                    top_xact_isolation = ISOLATION_LEVELS_BY_VALUE[
                        await self._connection.fetchval(
                            'SHOW transaction_isolation;')]
                if self._isolation != top_xact_isolation:
                    raise apg_errors.InterfaceError(
                        'nested transaction has a different isolation level: '
                        'current {!r} != outer {!r}'.format(
                            self._isolation, top_xact_isolation))
            self._nested = True

        if self._nested:
            self._id = con._get_unique_id('savepoint')
            query = 'SAVEPOINT {};'.format(self._id)
        else:
            query = 'BEGIN'
            if self._isolation == 'read_committed':
                query += ' ISOLATION LEVEL READ COMMITTED'
            elif self._isolation == 'read_uncommitted':
                query += ' ISOLATION LEVEL READ UNCOMMITTED'
            elif self._isolation == 'repeatable_read':
                query += ' ISOLATION LEVEL REPEATABLE READ'
            elif self._isolation == 'serializable':
                query += ' ISOLATION LEVEL SERIALIZABLE'
            if self._readonly:
                query += ' READ ONLY'
            if self._deferrable:
                query += ' DEFERRABLE'
            query += ';'

        try:
            await self._connection.execute(query)
        except BaseException:
            # BaseException so that cancellation also marks the object FAILED
            self._state = TransactionState.FAILED
            raise
        else:
            self._state = TransactionState.STARTED

    def __check_state_base(self, opname):
        """Raise ``InterfaceError`` if the transaction already finished.

        *opname* is only used to build the error message.
        """
        if self._state is TransactionState.COMMITTED:
            raise apg_errors.InterfaceError(
                'cannot {}; the transaction is already committed'.format(
                    opname))
        if self._state is TransactionState.ROLLEDBACK:
            raise apg_errors.InterfaceError(
                'cannot {}; the transaction is already rolled back'.format(
                    opname))
        if self._state is TransactionState.FAILED:
            raise apg_errors.InterfaceError(
                'cannot {}; the transaction is in error state'.format(
                    opname))

    def __check_state(self, opname):
        """Raise ``InterfaceError`` unless the transaction is STARTED."""
        if self._state is not TransactionState.STARTED:
            if self._state is TransactionState.NEW:
                raise apg_errors.InterfaceError(
                    'cannot {}; the transaction is not yet started'.format(
                        opname))
            self.__check_state_base(opname)

    async def __commit(self):
        """Send ``COMMIT`` (or ``RELEASE SAVEPOINT``) and record the outcome."""
        self.__check_state('commit')

        # Detach from the connection before the query is sent, so the
        # connection no longer refers to this object even if the query fails.
        if self._connection._top_xact is self:
            self._connection._top_xact = None

        if self._nested:
            query = 'RELEASE SAVEPOINT {};'.format(self._id)
        else:
            query = 'COMMIT;'

        try:
            await self._connection.execute(query)
        except BaseException:
            self._state = TransactionState.FAILED
            raise
        else:
            self._state = TransactionState.COMMITTED

    async def __rollback(self):
        """Send ``ROLLBACK`` (or ``ROLLBACK TO``) and record the outcome."""
        self.__check_state('rollback')

        # Same ordering as in __commit(): detach first, then run the query.
        if self._connection._top_xact is self:
            self._connection._top_xact = None

        if self._nested:
            query = 'ROLLBACK TO {};'.format(self._id)
        else:
            query = 'ROLLBACK;'

        try:
            await self._connection.execute(query)
        except BaseException:
            self._state = TransactionState.FAILED
            raise
        else:
            self._state = TransactionState.ROLLEDBACK

    @connresource.guarded
    async def commit(self):
        """Exit the transaction or savepoint block and commit changes."""
        if self._managed:
            raise apg_errors.InterfaceError(
                'cannot manually commit from within an `async with` block')
        await self.__commit()

    @connresource.guarded
    async def rollback(self):
        """Exit the transaction or savepoint block and rollback changes."""
        if self._managed:
            raise apg_errors.InterfaceError(
                'cannot manually rollback from within an `async with` block')
        await self.__rollback()

    def __repr__(self):
        """Return ``<module.Class state:... [options] 0xid>``."""
        attrs = []
        attrs.append('state:{}'.format(self._state.name.lower()))

        if self._isolation is not None:
            attrs.append(self._isolation)
        if self._readonly:
            attrs.append('readonly')
        if self._deferrable:
            attrs.append('deferrable')

        # Classes defined inside the package are shown as "asyncpg.<Name>"
        # rather than with their full submodule path.
        if self.__class__.__module__.startswith('asyncpg.'):
            mod = 'asyncpg'
        else:
            mod = self.__class__.__module__

        return '<{}.{} {} {:#x}>'.format(
            mod, self.__class__.__name__, ' '.join(attrs), id(self))
class Range(typing.Generic[_RV]):
    """Immutable representation of PostgreSQL `range` type.

    A bound of ``None`` means the range is unbounded on that side.  An
    unbounded side is never reported as inclusive, and an empty range has
    no bounds at all.
    """

    __slots__ = ('_lower', '_upper', '_lower_inc', '_upper_inc', '_empty')

    _lower: _RV | None
    _upper: _RV | None
    _lower_inc: bool
    _upper_inc: bool
    _empty: bool

    def __init__(
        self,
        lower: _RV | None = None,
        upper: _RV | None = None,
        *,
        lower_inc: bool = True,
        upper_inc: bool = False,
        empty: bool = False
    ) -> None:
        """Create a range.

        :param lower: lower bound, or ``None`` for no lower bound.
        :param upper: upper bound, or ``None`` for no upper bound.
        :param lower_inc: whether *lower* itself belongs to the range.
        :param upper_inc: whether *upper* itself belongs to the range.
        :param empty: create an empty range; the other arguments are then
            ignored.
        """
        self._empty = empty
        if empty:
            self._lower = self._upper = None
            self._lower_inc = self._upper_inc = False
        else:
            self._lower = lower
            self._upper = upper
            # a missing bound cannot be inclusive
            self._lower_inc = lower is not None and lower_inc
            self._upper_inc = upper is not None and upper_inc

    @property
    def lower(self) -> _RV | None:
        return self._lower

    @property
    def lower_inc(self) -> bool:
        return self._lower_inc

    @property
    def lower_inf(self) -> bool:
        return self._lower is None and not self._empty

    @property
    def upper(self) -> _RV | None:
        return self._upper

    @property
    def upper_inc(self) -> bool:
        return self._upper_inc

    @property
    def upper_inf(self) -> bool:
        return self._upper is None and not self._empty

    @property
    def isempty(self) -> bool:
        return self._empty

    def _issubset_lower(self, other: Self) -> bool:
        """Return True if our lower bound does not extend below *other*'s."""
        if other._lower is None:
            return True
        if self._lower is None:
            return False

        return self._lower > other._lower or (
            self._lower == other._lower
            and (other._lower_inc or not self._lower_inc)
        )

    def _issubset_upper(self, other: Self) -> bool:
        """Return True if our upper bound does not extend above *other*'s."""
        if other._upper is None:
            return True
        if self._upper is None:
            return False

        return self._upper < other._upper or (
            self._upper == other._upper
            and (other._upper_inc or not self._upper_inc)
        )

    def issubset(self, other: Self) -> bool:
        """Return True if every value of this range is also in *other*.

        An empty range is a subset of any range; no non-empty range is a
        subset of an empty one.
        """
        if self._empty:
            return True
        if other._empty:
            return False

        return self._issubset_lower(other) and self._issubset_upper(other)

    def issuperset(self, other: Self) -> bool:
        """Return True if every value of *other* is also in this range."""
        return other.issubset(self)

    def __bool__(self) -> bool:
        return not self._empty

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Range):
            return NotImplemented

        return (
            self._lower,
            self._upper,
            self._lower_inc,
            self._upper_inc,
            self._empty
        ) == (
            other._lower,  # pyright: ignore [reportUnknownMemberType]
            other._upper,  # pyright: ignore [reportUnknownMemberType]
            other._lower_inc,
            other._upper_inc,
            other._empty
        )

    def __hash__(self) -> int:
        return hash((
            self._lower,
            self._upper,
            self._lower_inc,
            self._upper_inc,
            self._empty
        ))

    def __repr__(self) -> str:
        """Return e.g. ``<Range [1, 5)>``, ``<Range (, 5]>`` or ``<Range empty>``."""
        if self._empty:
            desc = 'empty'
        else:
            if self._lower is None or not self._lower_inc:
                lb = '('
            else:
                lb = '['

            if self._lower is not None:
                lb += repr(self._lower)

            if self._upper is not None:
                ub = repr(self._upper)
            else:
                ub = ''

            if self._upper is None or not self._upper_inc:
                ub += ')'
            else:
                ub += ']'

            desc = '{}, {}'.format(lb, ub)

        # The template must contain a placeholder: formatting an empty
        # string discards ``desc`` and always yields ''.
        return '<Range {}>'.format(desc)

    __str__ = __repr__
cols = ['quote_literal(${}::{}::text)'.format(i, t) for i, t in enumerate(paramtypes, start=1)] textified = await conn.fetchrow( 'SELECT {cols}'.format(cols=', '.join(cols)), *args) # Finally, replace $n references with text values. return re.sub( r"\$(\d+)\b", lambda m: ( textified[int(m.group(1)) - 1] if textified[int(m.group(1)) - 1] is not None else "NULL" ), query, ) ================================================ FILE: docs/.gitignore ================================================ _build _templates ================================================ FILE: docs/Makefile ================================================ # Makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = python -m sphinx PAPER = BUILDDIR = _build # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . # the i18n builder cannot share the environment and doctrees with the others I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
# Print the list of supported build targets.
.PHONY: help
help:
	@echo "Please use \`make <target>' where <target> is one of"
	@echo "  html       to make standalone HTML files"
	@echo "  dirhtml    to make HTML files named index.html in directories"
	@echo "  singlehtml to make a single large HTML file"
	@echo "  pickle     to make pickle files"
	@echo "  json       to make JSON files"
	@echo "  htmlhelp   to make HTML files and a HTML help project"
	@echo "  qthelp     to make HTML files and a qthelp project"
	@echo "  applehelp  to make an Apple Help Book"
	@echo "  devhelp    to make HTML files and a Devhelp project"
	@echo "  epub       to make an epub"
	@echo "  epub3      to make an epub3"
	@echo "  latex      to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
	@echo "  latexpdf   to make LaTeX files and run them through pdflatex"
	@echo "  latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
	@echo "  text       to make text files"
	@echo "  man        to make manual pages"
	@echo "  texinfo    to make Texinfo files"
	@echo "  info       to make Texinfo files and run them through makeinfo"
	@echo "  gettext    to make PO message catalogs"
	@echo "  changes    to make an overview of all changed/added/deprecated items"
	@echo "  xml        to make Docutils-native XML files"
	@echo "  pseudoxml  to make pseudoxml-XML files for display purposes"
	@echo "  linkcheck  to check all external links for integrity"
	@echo "  doctest    to run all doctests embedded in the documentation (if enabled)"
	@echo "  coverage   to run coverage check of the documentation (if enabled)"
	@echo "  dummy      to check syntax errors of document sources"

.PHONY: clean
clean:
	rm -rf $(BUILDDIR)/*

.PHONY: html
html:
	$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
	@echo
	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."

.PHONY: dirhtml
dirhtml:
	$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
	@echo
	@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."

.PHONY: singlehtml
singlehtml:
	$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
	@echo
	@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."

.PHONY: pickle
pickle:
	$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
	@echo
	@echo "Build finished; now you can process the pickle files."

.PHONY: json
json:
	$(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
	@echo
	@echo "Build finished; now you can process the JSON files."

.PHONY: htmlhelp
htmlhelp:
	$(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
	@echo
	@echo "Build finished; now you can run HTML Help Workshop with the" \
	      ".hhp project file in $(BUILDDIR)/htmlhelp."

.PHONY: qthelp
qthelp:
	$(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
	@echo
	@echo "Build finished; now you can run "qcollectiongenerator" with the" \
	      ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
	@echo "# qcollectiongenerator $(BUILDDIR)/qthelp/asyncpg.qhcp"
	@echo "To view the help file:"
	@echo "# assistant -collectionFile $(BUILDDIR)/qthelp/asyncpg.qhc"

.PHONY: applehelp
applehelp:
	$(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp
	@echo
	@echo "Build finished. The help book is in $(BUILDDIR)/applehelp."
	@echo "N.B. You won't be able to view it unless you put it in" \
	      "~/Library/Documentation/Help or install it in your application" \
	      "bundle."

.PHONY: devhelp
devhelp:
	$(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
	@echo
	@echo "Build finished."
	@echo "To view the help file:"
	@echo "# mkdir -p $$HOME/.local/share/devhelp/asyncpg"
	@echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/asyncpg"
	@echo "# devhelp"

.PHONY: epub
epub:
	$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
	@echo
	@echo "Build finished. The epub file is in $(BUILDDIR)/epub."

.PHONY: epub3
epub3:
	$(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3
	@echo
	@echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3."

.PHONY: latex
latex:
	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
	@echo
	@echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
	@echo "Run \`make' in that directory to run these through (pdf)latex" \
	      "(use \`make latexpdf' here to do that automatically)."

.PHONY: latexpdf
latexpdf:
	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
	@echo "Running LaTeX files through pdflatex..."
	$(MAKE) -C $(BUILDDIR)/latex all-pdf
	@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."

.PHONY: latexpdfja
latexpdfja:
	$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
	@echo "Running LaTeX files through platex and dvipdfmx..."
	$(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
	@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."

.PHONY: text
text:
	$(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
	@echo
	@echo "Build finished. The text files are in $(BUILDDIR)/text."

.PHONY: man
man:
	$(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
	@echo
	@echo "Build finished. The manual pages are in $(BUILDDIR)/man."

.PHONY: texinfo
texinfo:
	$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
	@echo
	@echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
	@echo "Run \`make' in that directory to run these through makeinfo" \
	      "(use \`make info' here to do that automatically)."

# Use $(MAKE) rather than a bare `make' for the sub-make, as the
# latexpdf targets above do, so that flags and the jobserver are
# passed down to the recursive invocation.
.PHONY: info
info:
	$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
	@echo "Running Texinfo files through makeinfo..."
	$(MAKE) -C $(BUILDDIR)/texinfo info
	@echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."

.PHONY: gettext
gettext:
	$(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
	@echo
	@echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."

.PHONY: changes
changes:
	$(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
	@echo
	@echo "The overview file is in $(BUILDDIR)/changes."

.PHONY: linkcheck
linkcheck:
	$(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
	@echo
	@echo "Link check complete; look for any errors in the above output " \
	      "or in $(BUILDDIR)/linkcheck/output.txt."

.PHONY: doctest
doctest:
	$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
	@echo "Testing of doctests in the sources finished, look at the " \
	      "results in $(BUILDDIR)/doctest/output.txt."

.PHONY: coverage
coverage:
	$(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage
	@echo "Testing of coverage in the sources finished, look at the " \
	      "results in $(BUILDDIR)/coverage/python.txt."

.PHONY: xml
xml:
	$(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
	@echo
	@echo "Build finished. The XML files are in $(BUILDDIR)/xml."

.PHONY: pseudoxml
pseudoxml:
	$(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
	@echo
	@echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."

.PHONY: dummy
dummy:
	$(SPHINXBUILD) -b dummy $(ALLSPHINXOPTS) $(BUILDDIR)/dummy
	@echo
	@echo "Build finished. Dummy builder generates no files."


================================================
FILE: docs/_static/theme_overrides.css
================================================
/* override table width restrictions */
@media screen and (min-width: 767px) {

    .wy-table-responsive table td {
        white-space: normal !important;
        vertical-align: top !important;
    }

    .wy-table-responsive {
        overflow: visible !important;
    }
}


================================================
FILE: docs/api/index.rst
================================================
.. _asyncpg-api-reference:

=============
API Reference
=============

.. module:: asyncpg
    :synopsis: A fast PostgreSQL Database Client Library for Python/asyncio

.. currentmodule:: asyncpg


.. _asyncpg-api-connection:

Connection
==========

.. autofunction:: asyncpg.connection.connect


.. autoclass:: asyncpg.connection.Connection
   :members:


.. _asyncpg-api-prepared-stmt:

Prepared Statements
===================

Prepared statements are a PostgreSQL feature that can be used to optimize the
performance of queries that are executed more than once.
When a query is *prepared* by a call to :meth:`Connection.prepare`, the server parses, analyzes and compiles the query allowing to reuse that work once there is a need to run the same query again. .. code-block:: pycon >>> import asyncpg, asyncio >>> async def run(): ... conn = await asyncpg.connect() ... stmt = await conn.prepare('''SELECT 2 ^ $1''') ... print(await stmt.fetchval(10)) ... print(await stmt.fetchval(20)) ... >>> asyncio.run(run()) 1024.0 1048576.0 .. note:: asyncpg automatically maintains a small LRU cache for queries executed during calls to the :meth:`~Connection.fetch`, :meth:`~Connection.fetchrow`, or :meth:`~Connection.fetchval` methods. .. warning:: If you are using pgbouncer with ``pool_mode`` set to ``transaction`` or ``statement``, prepared statements will not work correctly. See :ref:`asyncpg-prepared-stmt-errors` for more information. .. autoclass:: asyncpg.prepared_stmt.PreparedStatement() :members: .. _asyncpg-api-transaction: Transactions ============ The most common way to use transactions is through an ``async with`` statement: .. code-block:: python async with connection.transaction(): await connection.execute("INSERT INTO mytable VALUES(1, 2, 3)") asyncpg supports nested transactions (a nested transaction context will create a `savepoint`_.): .. code-block:: python async with connection.transaction(): await connection.execute('CREATE TABLE mytab (a int)') try: # Create a nested transaction: async with connection.transaction(): await connection.execute('INSERT INTO mytab (a) VALUES (1), (2)') # This nested transaction will be automatically rolled back: raise Exception except: # Ignore exception pass # Because the nested transaction was rolled back, there # will be nothing in `mytab`. assert await connection.fetch('SELECT a FROM mytab') == [] Alternatively, transactions can be used without an ``async with`` block: .. code-block:: python tr = connection.transaction() await tr.start() try: ... 
except: await tr.rollback() raise else: await tr.commit() See also the :meth:`Connection.transaction() <asyncpg.connection.Connection.transaction>` function. .. _savepoint: https://www.postgresql.org/docs/current/static/sql-savepoint.html .. autoclass:: asyncpg.transaction.Transaction() :members: .. describe:: async with c: start and commit/rollback the transaction or savepoint block automatically when entering and exiting the code inside the context manager block. .. _asyncpg-api-cursor: Cursors ======= Cursors are useful when there is a need to iterate over the results of a large query without fetching all rows at once. The cursor interface provided by asyncpg supports *asynchronous iteration* via the ``async for`` statement, and also a way to read row chunks and skip forward over the result set. To iterate over a cursor using a connection object use :meth:`Connection.cursor() <asyncpg.connection.Connection.cursor>`. To make the iteration efficient, the cursor will prefetch records to reduce the number of queries sent to the server: .. code-block:: python async def iterate(con: Connection): async with con.transaction(): # Postgres requires non-scrollable cursors to be created # and used in a transaction. async for record in con.cursor('SELECT generate_series(0, 100)'): print(record) Or, alternatively, you can iterate over the cursor manually (cursor won't be prefetching any rows): .. code-block:: python async def iterate(con: Connection): async with con.transaction(): # Postgres requires non-scrollable cursors to be created # and used in a transaction. # Create a Cursor object cur = await con.cursor('SELECT generate_series(0, 100)') # Move the cursor 10 rows forward await cur.forward(10) # Fetch one row and print it print(await cur.fetchrow()) # Fetch a list of 5 rows and print it print(await cur.fetch(5)) It's also possible to create cursors from prepared statements: ..
code-block:: python async def iterate(con: Connection): # Create a prepared statement that will accept one argument stmt = await con.prepare('SELECT generate_series(0, $1)') async with con.transaction(): # Postgres requires non-scrollable cursors to be created # and used in a transaction. # Execute the prepared statement passing `10` as the # argument -- that will generate a series of records # from 0..10. Iterate over all of them and print every # record. async for record in stmt.cursor(10): print(record) .. note:: Cursors created by a call to :meth:`Connection.cursor() <asyncpg.connection.Connection.cursor>` or :meth:`PreparedStatement.cursor() <asyncpg.prepared_stmt.PreparedStatement.cursor>` are *non-scrollable*: they can only be read forwards. To create a scrollable cursor, use the ``DECLARE ... SCROLL CURSOR`` SQL statement directly. .. warning:: Cursors created by a call to :meth:`Connection.cursor() <asyncpg.connection.Connection.cursor>` or :meth:`PreparedStatement.cursor() <asyncpg.prepared_stmt.PreparedStatement.cursor>` cannot be used outside of a transaction. Any such attempt will result in :exc:`~asyncpg.exceptions.InterfaceError`. To create a cursor usable outside of a transaction, use the ``DECLARE ... CURSOR WITH HOLD`` SQL statement directly. .. autoclass:: asyncpg.cursor.CursorFactory() :members: .. describe:: async for row in c Execute the statement and iterate over the results asynchronously. .. describe:: await c Execute the statement and return an instance of :class:`~asyncpg.cursor.Cursor` which can be used to navigate over and fetch subsets of the query results. .. autoclass:: asyncpg.cursor.Cursor() :members: .. _asyncpg-api-pool: Connection Pools ================ .. autofunction:: asyncpg.pool.create_pool .. autoclass:: asyncpg.pool.Pool() :members: .. _asyncpg-api-record: Record Objects ============== Each row (or composite type value) returned by calls to ``fetch*`` methods is represented by an instance of the :class:`~asyncpg.Record` object. ``Record`` objects are a tuple-/dict-like hybrid, and allow addressing of items either by a numeric index or by a field name: ..
code-block:: pycon >>> import asyncpg >>> import asyncio >>> loop = asyncio.get_event_loop() >>> conn = loop.run_until_complete(asyncpg.connect()) >>> r = loop.run_until_complete(conn.fetchrow(''' ... SELECT oid, rolname, rolsuper FROM pg_roles WHERE rolname = user''')) >>> r <Record oid=16388 rolname='elvis' rolsuper=True> >>> r['oid'] 16388 >>> r[0] 16388 >>> dict(r) {'oid': 16388, 'rolname': 'elvis', 'rolsuper': True} >>> tuple(r) (16388, 'elvis', True) .. note:: ``Record`` objects currently cannot be created from Python code. .. class:: Record() A read-only representation of PostgreSQL row. .. describe:: len(r) Return the number of fields in record *r*. .. describe:: r[field] Return the field of *r* with field name or index *field*. .. describe:: name in r Return ``True`` if record *r* has a field named *name*. .. describe:: iter(r) Return an iterator over the *values* of the record *r*. .. describe:: get(name[, default]) Return the value for *name* if the record has a field named *name*, else return *default*. If *default* is not given, return ``None``. .. versionadded:: 0.18 .. method:: values() Return an iterator over the record values. .. method:: keys() Return an iterator over the record field names. .. method:: items() Return an iterator over ``(field, value)`` pairs. .. class:: ConnectionSettings() A read-only collection of Connection settings. .. describe:: settings.setting_name Return the value of the "setting_name" setting. Raises an ``AttributeError`` if the setting is not defined. Example: .. code-block:: pycon >>> connection.get_settings().client_encoding 'UTF8' Data Types ========== ..
automodule:: asyncpg.types :members: ================================================ FILE: docs/conf.py ================================================ #!/usr/bin/env python3 import os import sys sys.path.insert(0, os.path.abspath('..')) version_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'asyncpg', '_version.py') with open(version_file, 'r') as f: for line in f: if line.startswith('__version__: typing.Final ='): _, _, version = line.partition('=') version = version.strip(" \n'\"") break else: raise RuntimeError( 'unable to read the version from asyncpg/_version.py') # -- General configuration ------------------------------------------------ extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.viewcode', 'sphinx.ext.githubpages', 'sphinx.ext.intersphinx', ] add_module_names = False templates_path = ['_templates'] source_suffix = '.rst' master_doc = 'index' project = 'asyncpg' copyright = '2016-present, the asyncpg authors and contributors' author = '' release = version language = "en" exclude_patterns = ['_build'] pygments_style = 'sphinx' todo_include_todos = False suppress_warnings = ['image.nonlocal_uri'] # -- Options for HTML output ---------------------------------------------- html_theme = 'sphinx_rtd_theme' html_title = 'asyncpg Documentation' html_short_title = 'asyncpg' html_static_path = ['_static'] html_sidebars = { '**': [ 'about.html', 'navigation.html', ] } html_show_sourcelink = False html_show_sphinx = False html_show_copyright = True htmlhelp_basename = 'asyncpgdoc' # -- Options for LaTeX output --------------------------------------------- latex_elements = {} latex_documents = [ (master_doc, 'asyncpg.tex', 'asyncpg Documentation', author, 'manual'), ] # -- Options for manual page output --------------------------------------- man_pages = [ (master_doc, 'asyncpg', 'asyncpg Documentation', [author], 1) ] # -- Options for Texinfo output ------------------------------------------- texinfo_documents = [ 
(master_doc, 'asyncpg', 'asyncpg Documentation', author, 'asyncpg', 'asyncpg is a fast PostgreSQL client library for the ' 'Python asyncio framework', 'Miscellaneous'), ] # -- Options for intersphinx ---------------------------------------------- intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} ================================================ FILE: docs/faq.rst ================================================ .. _asyncpg-faq: Frequently Asked Questions ========================== Does asyncpg support DB-API? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ No. DB-API is a synchronous API, while asyncpg is based around an asynchronous I/O model. Thus, full drop-in compatibility with DB-API is not possible and we decided to design asyncpg API in a way that is better aligned with PostgreSQL architecture and terminology. We will release a synchronous DB-API-compatible version of asyncpg at some point in the future. Can I use asyncpg with SQLAlchemy ORM? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Yes. SQLAlchemy version 1.4 and later supports the asyncpg dialect natively. Please refer to its documentation for details. Older SQLAlchemy versions may be used in tandem with a third-party adapter such as asyncpgsa_ or databases_. Can I use dot-notation with :class:`asyncpg.Record`? It looks cleaner. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ We decided against making :class:`asyncpg.Record` a named tuple because we want to keep the ``Record`` method namespace separate from the column namespace. That said, you can provide a custom ``Record`` class that implements dot-notation via the ``record_class`` argument to :func:`connect() ` or any of the Record-returning methods. .. code-block:: python class MyRecord(asyncpg.Record): def __getattr__(self, name): return self[name] Why can't I use a :ref:`cursor ` outside of a transaction? 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Cursors created by a call to :meth:`Connection.cursor() <asyncpg.connection.Connection.cursor>` or :meth:`PreparedStatement.cursor() \ <asyncpg.prepared_stmt.PreparedStatement.cursor>` cannot be used outside of a transaction. Any such attempt will result in ``InterfaceError``. To create a cursor usable outside of a transaction, use the ``DECLARE ... CURSOR WITH HOLD`` SQL statement directly. .. _asyncpg-prepared-stmt-errors: Why am I getting prepared statement errors? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ If you are getting intermittent ``prepared statement "__asyncpg_stmt_xx__" does not exist`` or ``prepared statement "__asyncpg_stmt_xx__" already exists`` errors, you are most likely not connecting to the PostgreSQL server directly, but via `pgbouncer <https://www.pgbouncer.org/>`_. pgbouncer, when in the ``"transaction"`` or ``"statement"`` pooling mode, does not support prepared statements. You have several options: * if you are using pgbouncer only to reduce the cost of new connections (as opposed to using pgbouncer for connection pooling from a large number of clients in the interest of better scalability), switch to the :ref:`connection pool <asyncpg-api-pool>` functionality provided by asyncpg, it is a much better option for this purpose; * disable automatic use of prepared statements by passing ``statement_cache_size=0`` to :func:`asyncpg.connect() <asyncpg.connection.connect>` and :func:`asyncpg.create_pool() <asyncpg.pool.create_pool>` (and, obviously, avoid the use of :meth:`Connection.prepare() <asyncpg.connection.Connection.prepare>`); * switch pgbouncer's ``pool_mode`` to ``session``. Why do I get ``PostgresSyntaxError`` when using ``expression IN $1``? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ``expression IN $1`` is not a valid PostgreSQL syntax. To check a value against a sequence use ``expression = any($1::mytype[])``, where ``mytype`` is the array element type. .. _asyncpgsa: https://github.com/CanopyTax/asyncpgsa ..
_databases: https://github.com/encode/databases ================================================ FILE: docs/index.rst ================================================ .. image:: https://github.com/MagicStack/asyncpg/workflows/Tests/badge.svg :target: https://github.com/MagicStack/asyncpg/actions?query=workflow%3ATests+branch%3Amaster :alt: GitHub Actions status .. image:: https://img.shields.io/pypi/status/asyncpg.svg?maxAge=2592000?style=plastic :target: https://pypi.python.org/pypi/asyncpg ======= asyncpg ======= **asyncpg** is a database interface library designed specifically for PostgreSQL and Python/asyncio. asyncpg is an efficient, clean implementation of PostgreSQL server binary protocol for use with Python's ``asyncio`` framework. **asyncpg** requires Python 3.9 or later and is supported for PostgreSQL versions 9.5 to 18. Other PostgreSQL versions or other databases implementing the PostgreSQL protocol *may* work, but are not being actively tested. Contents -------- .. toctree:: :maxdepth: 2 installation usage api/index faq ================================================ FILE: docs/installation.rst ================================================ .. _asyncpg-installation: Installation ============ **asyncpg** has no external dependencies when not using GSSAPI/SSPI authentication. The recommended way to install it is to use **pip**: .. code-block:: bash $ pip install asyncpg If you need GSSAPI/SSPI authentication, the recommended way is to use .. code-block:: bash $ pip install 'asyncpg[gssauth]' This installs SSPI support on Windows and GSSAPI support on non-Windows platforms. SSPI and GSSAPI interoperate as clients and servers: an SSPI client can authenticate to a GSSAPI server and vice versa. On Linux installing GSSAPI requires a working C compiler and Kerberos 5 development files. The latter can be obtained by installing **libkrb5-dev** package on Debian/Ubuntu or **krb5-devel** on RHEL/Fedora. 
(This is needed because PyPI does not have Linux wheels for **gssapi**. See `here for the details `_.) It is also possible to use GSSAPI on Windows: * `pip install gssapi` * Install `Kerberos for Windows `_. * Set the ``gsslib`` parameter or the ``PGGSSLIB`` environment variable to `gssapi` when connecting. Building from source -------------------- If you want to build **asyncpg** from a Git checkout you will need: * To have cloned the repo with `--recurse-submodules`. * A working C compiler. * CPython header files. These can usually be obtained by installing the relevant Python development package: **python3-dev** on Debian/Ubuntu, **python3-devel** on RHEL/Fedora. Once the above requirements are satisfied, run the following command in the root of the source checkout: .. code-block:: bash $ pip install -e . A debug build containing more runtime checks can be created by setting the ``ASYNCPG_DEBUG`` environment variable when building: .. code-block:: bash $ env ASYNCPG_DEBUG=1 pip install -e . Running tests ------------- If you want to run tests you must have PostgreSQL installed. To execute the testsuite run: .. code-block:: bash $ python setup.py test ================================================ FILE: docs/requirements.txt ================================================ sphinxcontrib-asyncio sphinx_rtd_theme ================================================ FILE: docs/usage.rst ================================================ .. _asyncpg-examples: asyncpg Usage ============= The interaction with the database normally starts with a call to :func:`connect() `, which establishes a new database session and returns a new :class:`Connection ` instance, which provides methods to run queries and manage transactions. .. code-block:: python import asyncio import asyncpg import datetime async def main(): # Establish a connection to an existing database named "test" # as a "postgres" user. 
conn = await asyncpg.connect('postgresql://postgres@localhost/test') # Execute a statement to create a new table. await conn.execute(''' CREATE TABLE users( id serial PRIMARY KEY, name text, dob date ) ''') # Insert a record into the created table. await conn.execute(''' INSERT INTO users(name, dob) VALUES($1, $2) ''', 'Bob', datetime.date(1984, 3, 1)) # Select a row from the table. row = await conn.fetchrow( 'SELECT * FROM users WHERE name = $1', 'Bob') # *row* now contains # asyncpg.Record(id=1, name='Bob', dob=datetime.date(1984, 3, 1)) # Close the connection. await conn.close() asyncio.run(main()) .. note:: asyncpg uses the native PostgreSQL syntax for query arguments: ``$n``. Type Conversion --------------- asyncpg automatically converts PostgreSQL types to the corresponding Python types and vice versa. All standard data types are supported out of the box, including arrays, composite types, range types, enumerations and any combination of them. It is possible to supply codecs for non-standard types or override standard codecs. See :ref:`asyncpg-custom-codecs` for more information. The table below shows the correspondence between PostgreSQL and Python types.
+----------------------+-----------------------------------------------------+ | PostgreSQL Type | Python Type | +======================+=====================================================+ | ``anyarray`` | :class:`list <python:list>` | +----------------------+-----------------------------------------------------+ | ``anyenum`` | :class:`str <python:str>` | +----------------------+-----------------------------------------------------+ | ``anyrange`` | :class:`asyncpg.Range <asyncpg.types.Range>`, | | | :class:`tuple <python:tuple>` | +----------------------+-----------------------------------------------------+ | ``anymultirange`` | ``list[``:class:`asyncpg.Range\ | | | <asyncpg.types.Range>` ``]``, | | | ``list[``:class:`tuple <python:tuple>` ``]`` [#f1]_ | +----------------------+-----------------------------------------------------+ | ``record`` | :class:`asyncpg.Record`, | | | :class:`tuple <python:tuple>`, | | | :class:`Mapping <python:collections.abc.Mapping>` | +----------------------+-----------------------------------------------------+ | ``bit``, ``varbit`` | :class:`asyncpg.BitString <asyncpg.types.BitString>`| +----------------------+-----------------------------------------------------+ | ``bool`` | :class:`bool <python:bool>` | +----------------------+-----------------------------------------------------+ | ``box`` | :class:`asyncpg.Box <asyncpg.types.Box>` | +----------------------+-----------------------------------------------------+ | ``bytea`` | :class:`bytes <python:bytes>` | +----------------------+-----------------------------------------------------+ | ``char``, ``name``, | :class:`str <python:str>` | | ``varchar``, | | | ``text``, | | | ``xml`` | | +----------------------+-----------------------------------------------------+ | ``cidr`` | :class:`ipaddress.IPv4Network\ | | | <python:ipaddress.IPv4Network>`, | | | :class:`ipaddress.IPv6Network\ | | | <python:ipaddress.IPv6Network>` | +----------------------+-----------------------------------------------------+ | ``inet`` | :class:`ipaddress.IPv4Interface\ | | | <python:ipaddress.IPv4Interface>`, | | | :class:`ipaddress.IPv6Interface\ | | | <python:ipaddress.IPv6Interface>`, | | | :class:`ipaddress.IPv4Address\ | | | <python:ipaddress.IPv4Address>`, | | | :class:`ipaddress.IPv6Address\ | | | <python:ipaddress.IPv6Address>` [#f2]_ |
+----------------------+-----------------------------------------------------+ | ``macaddr`` | :class:`str <python:str>` | +----------------------+-----------------------------------------------------+ | ``circle`` | :class:`asyncpg.Circle <asyncpg.types.Circle>` | +----------------------+-----------------------------------------------------+ | ``date`` | :class:`datetime.date <python:datetime.date>` | +----------------------+-----------------------------------------------------+ | ``time`` | offset-naïve :class:`datetime.time \ | | | <python:datetime.time>` | +----------------------+-----------------------------------------------------+ | ``time with | offset-aware :class:`datetime.time \ | | time zone`` | <python:datetime.time>` | +----------------------+-----------------------------------------------------+ | ``timestamp`` | offset-naïve :class:`datetime.datetime \ | | | <python:datetime.datetime>` | +----------------------+-----------------------------------------------------+ | ``timestamp with | offset-aware :class:`datetime.datetime \ | | time zone`` | <python:datetime.datetime>` | +----------------------+-----------------------------------------------------+ | ``interval`` | :class:`datetime.timedelta \ | | | <python:datetime.timedelta>` | +----------------------+-----------------------------------------------------+ | ``float``, | :class:`float <python:float>` [#f3]_ | | ``double precision`` | | +----------------------+-----------------------------------------------------+ | ``smallint``, | :class:`int <python:int>` | | ``integer``, | | | ``bigint`` | | +----------------------+-----------------------------------------------------+ | ``numeric`` | :class:`Decimal <python:decimal.Decimal>` | +----------------------+-----------------------------------------------------+ | ``json``, ``jsonb`` | :class:`str <python:str>` | +----------------------+-----------------------------------------------------+ | ``line`` | :class:`asyncpg.Line <asyncpg.types.Line>` | +----------------------+-----------------------------------------------------+ | ``lseg`` | :class:`asyncpg.LineSegment \ | | | <asyncpg.types.LineSegment>` | +----------------------+-----------------------------------------------------+ | ``money`` | :class:`str <python:str>` |
+----------------------+-----------------------------------------------------+ | ``path`` | :class:`asyncpg.Path <asyncpg.types.Path>` | +----------------------+-----------------------------------------------------+ | ``point`` | :class:`asyncpg.Point <asyncpg.types.Point>` | +----------------------+-----------------------------------------------------+ | ``polygon`` | :class:`asyncpg.Polygon <asyncpg.types.Polygon>` | +----------------------+-----------------------------------------------------+ | ``uuid`` | :class:`uuid.UUID <python:uuid.UUID>` | +----------------------+-----------------------------------------------------+ | ``tid`` | :class:`tuple <python:tuple>` | +----------------------+-----------------------------------------------------+ All other types are encoded and decoded as text by default. .. [#f1] Since version 0.25.0 .. [#f2] Prior to version 0.20.0, asyncpg erroneously treated ``inet`` values with prefix as ``IPvXNetwork`` instead of ``IPvXInterface``. .. [#f3] Inexact single-precision ``float`` values may have a different representation when decoded into a Python float. This is inherent to the implementation of limited-precision floating point types. If you need the decimal representation to match, cast the expression to ``double`` or ``numeric`` in your query. .. _asyncpg-custom-codecs: Custom Type Conversions ----------------------- asyncpg allows defining custom type conversion functions both for standard and user-defined types using the :meth:`Connection.set_type_codec() \ <asyncpg.connection.Connection.set_type_codec>` and :meth:`Connection.set_builtin_type_codec() \ <asyncpg.connection.Connection.set_builtin_type_codec>` methods. Example: automatic JSON conversion ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The example below shows how to configure asyncpg to encode and decode JSON values using the :mod:`json <python:json>` module. ..
code-block:: python import asyncio import asyncpg import json async def main(): conn = await asyncpg.connect() try: await conn.set_type_codec( 'json', encoder=json.dumps, decoder=json.loads, schema='pg_catalog' ) data = {'foo': 'bar', 'spam': 1} res = await conn.fetchval('SELECT $1::json', data) finally: await conn.close() asyncio.run(main()) Example: complex types ~~~~~~~~~~~~~~~~~~~~~~ The example below shows how to configure asyncpg to encode and decode Python :class:`complex ` values to a custom composite type in PostgreSQL. .. code-block:: python import asyncio import asyncpg async def main(): conn = await asyncpg.connect() try: await conn.execute( ''' CREATE TYPE mycomplex AS ( r float, i float );''' ) await conn.set_type_codec( 'mycomplex', encoder=lambda x: (x.real, x.imag), decoder=lambda t: complex(t[0], t[1]), format='tuple', ) res = await conn.fetchval('SELECT $1::mycomplex', (1+2j)) finally: await conn.close() asyncio.run(main()) Example: automatic conversion of PostGIS types ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The example below shows how to configure asyncpg to encode and decode the PostGIS ``geometry`` type. It works for any Python object that conforms to the `geo interface specification`_ and relies on Shapely_, although any library that supports reading and writing the WKB format will work. .. _Shapely: https://github.com/Toblerity/Shapely .. _geo interface specification: https://gist.github.com/sgillies/2217756 ..
code-block:: python import asyncio import asyncpg import shapely.geometry import shapely.wkb from shapely.geometry.base import BaseGeometry async def main(): conn = await asyncpg.connect() try: def encode_geometry(geometry): if not hasattr(geometry, '__geo_interface__'): raise TypeError('{g} does not conform to ' 'the geo interface'.format(g=geometry)) shape = shapely.geometry.shape(geometry) return shapely.wkb.dumps(shape) def decode_geometry(wkb): return shapely.wkb.loads(wkb) await conn.set_type_codec( 'geometry', # also works for 'geography' encoder=encode_geometry, decoder=decode_geometry, format='binary', ) data = shapely.geometry.Point(-73.985661, 40.748447) res = await conn.fetchrow( '''SELECT 'Empire State Building' AS name, $1::geometry AS coordinates ''', data) print(res) finally: await conn.close() asyncio.run(main()) Example: decoding numeric columns as floats ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ By default asyncpg decodes numeric columns as Python :class:`Decimal ` instances. The example below shows how to instruct asyncpg to use floats instead. .. code-block:: python import asyncio import asyncpg async def main(): conn = await asyncpg.connect() try: await conn.set_type_codec( 'numeric', encoder=str, decoder=float, schema='pg_catalog', format='text' ) res = await conn.fetchval("SELECT $1::numeric", 11.123) print(res, type(res)) finally: await conn.close() asyncio.run(main()) Example: decoding hstore values ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ hstore_ is an extension data type used for storing key/value pairs. asyncpg includes a codec to decode and encode hstore values as ``dict`` objects. Because ``hstore`` is not a builtin type, the codec must be registered on a connection using :meth:`Connection.set_builtin_type_codec() `: .. code-block:: python import asyncpg import asyncio async def run(): conn = await asyncpg.connect() # Assuming the hstore extension exists in the public schema. 
await conn.set_builtin_type_codec( 'hstore', codec_name='pg_contrib.hstore') result = await conn.fetchval("SELECT 'a=>1,b=>2,c=>NULL'::hstore") assert result == {'a': '1', 'b': '2', 'c': None} asyncio.run(run()) .. _hstore: https://www.postgresql.org/docs/current/static/hstore.html Transactions ------------ To create transactions, the :meth:`Connection.transaction() ` method should be used. The most common way to use transactions is through an ``async with`` statement: .. code-block:: python async with connection.transaction(): await connection.execute("INSERT INTO mytable VALUES(1, 2, 3)") .. note:: When not in an explicit transaction block, any changes to the database will be applied immediately. This is also known as *auto-commit*. See the :ref:`asyncpg-api-transaction` API documentation for more information. .. _asyncpg-connection-pool: Connection Pools ---------------- For server-type applications that handle frequent requests and need the database connection for a short period of time while handling a request, the use of a connection pool is recommended. asyncpg provides an advanced pool implementation, which eliminates the need to use an external connection pooler such as PgBouncer. To create a connection pool, use the :func:`asyncpg.create_pool() ` function. The resulting :class:`Pool ` object can then be used to borrow connections from the pool. Below is an example of how **asyncpg** can be used to implement a simple Web service that computes the requested power of two. .. code-block:: python import asyncio import asyncpg from aiohttp import web async def handle(request): """Handle incoming requests.""" pool = request.app['pool'] power = int(request.match_info.get('power', 10)) # Take a connection from the pool. async with pool.acquire() as connection: # Open a transaction. async with connection.transaction(): # Run the query passing the request argument.
result = await connection.fetchval('select 2 ^ $1', power) return web.Response( text="2 ^ {} is {}".format(power, result)) async def init_db(app): """Initialize a connection pool.""" app['pool'] = await asyncpg.create_pool(database='postgres', user='postgres') yield await app['pool'].close() def init_app(): """Initialize the application server.""" app = web.Application() # Create a database context app.cleanup_ctx.append(init_db) # Configure service routes app.router.add_route('GET', r'/{power:\d+}', handle) app.router.add_route('GET', '/', handle) return app app = init_app() web.run_app(app) See :ref:`asyncpg-api-pool` API documentation for more information. ================================================ FILE: pyproject.toml ================================================ [project] name = "asyncpg" description = "An asyncio PostgreSQL driver" authors = [{name = "MagicStack Inc", email = "hello@magic.io"}] requires-python = '>=3.9.0' readme = "README.rst" license = "Apache-2.0" license-files = ["LICENSE"] dynamic = ["version"] keywords = [ "database", "postgres", ] classifiers = [ "Development Status :: 5 - Production/Stable", "Framework :: AsyncIO", "Intended Audience :: Developers", "Operating System :: POSIX", "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Free Threading :: 2 - Beta", "Topic :: Database :: Front-Ends", ] dependencies = [ 'async_timeout>=4.0.3; python_version < "3.11.0"', ] [project.urls] github = "https://github.com/MagicStack/asyncpg" [project.optional-dependencies] gssauth = [ 'gssapi; platform_system != "Windows"',
'sspilib; platform_system == "Windows"', ] [dependency-groups] test = [ 'flake8~=6.1', 'flake8-pyi~=24.1.0', 'distro~=1.9.0', 'uvloop>=0.22.1; platform_system != "Windows" and python_version < "3.15.0"', 'gssapi; platform_system == "Linux"', 'k5test; platform_system == "Linux"', 'sspilib; platform_system == "Windows"', 'mypy~=1.8.0', 'pytest', ] docs = [ 'Sphinx~=7.4', 'sphinx_rtd_theme>=1.2.2', ] [build-system] requires = [ "setuptools>=77.0.3", "Cython(>=3.2.1,<4.0.0)" ] build-backend = "setuptools.build_meta" [tool.setuptools] zip-safe = false [tool.setuptools.packages.find] include = ["asyncpg", "asyncpg.*"] [tool.setuptools.exclude-package-data] "*" = ["*.c", "*.h"] [tool.cibuildwheel] build-frontend = "build" test-groups = "test" skip = "cp38-*" [tool.cibuildwheel.macos] before-all = ".github/workflows/install-postgres.sh" test-command = "python {project}/tests/__init__.py" [tool.cibuildwheel.windows] test-command = "python {project}\\tests\\__init__.py" [tool.cibuildwheel.linux] before-all = """ .github/workflows/install-postgres.sh \ && .github/workflows/install-krb5.sh \ """ test-command = """\ PY=`which python` \ && chmod -R go+rX "$(dirname $(dirname $(dirname $PY)))" \ && su -l apgtest -c "$PY {project}/tests/__init__.py" \ """ [tool.pytest.ini_options] addopts = "--capture=no --assert=plain --strict-markers --tb=native --import-mode=importlib" testpaths = "tests" filterwarnings = "default" [tool.coverage.run] branch = true plugins = ["Cython.Coverage"] parallel = true source = ["asyncpg/", "tests/"] omit = ["*.pxd"] [tool.coverage.report] exclude_lines = [ "pragma: no cover", "def __repr__", "if debug", "raise NotImplementedError", "if __name__ == .__main__.", ] show_missing = true [tool.mypy] exclude = [ "^.eggs", "^.github", "^.vscode", "^build", "^dist", "^docs", "^tests", ] incremental = true strict = true implicit_reexport = true [[tool.mypy.overrides]] module = [ "asyncpg._testbase", "asyncpg._testbase.*", "asyncpg.cluster", 
"asyncpg.connect_utils", "asyncpg.connection", "asyncpg.connresource", "asyncpg.cursor", "asyncpg.exceptions", "asyncpg.exceptions.*", "asyncpg.pool", "asyncpg.prepared_stmt", "asyncpg.transaction", "asyncpg.utils", ] ignore_errors = true ================================================ FILE: setup.py ================================================
# Copyright (C) 2016-present the asyncpg authors and contributors
#
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import sys

# Fail early, before anything else is imported, on unsupported interpreters.
if sys.version_info < (3, 9):
    raise RuntimeError('asyncpg requires Python 3.9 or greater')

import os
import os.path
import pathlib
import platform
import re
import subprocess

# We use vanilla build_ext, to avoid importing Cython via
# the setuptools version.
import setuptools
from setuptools.command import build_py as setuptools_build_py
from setuptools.command import sdist as setuptools_sdist
from setuptools.command import build_ext as setuptools_build_ext


# Requirement string for the Cython compiler; reused both for the
# runtime version check and for ``setup_requires``.
CYTHON_DEPENDENCY = 'Cython(>=3.2.1,<4.0.0)'

# Extra compiler/linker flags passed to every extension module.
CFLAGS = ['-O2']
LDFLAGS = []

if platform.uname().system != 'Windows':
    CFLAGS.extend(['-fsigned-char', '-Wall', '-Wsign-compare', '-Wconversion'])
    # Link against libm (math library) for functions like log10()
    LDFLAGS.extend(['-lm'])


# Directory containing this setup.py.
_ROOT = pathlib.Path(__file__).parent


with open(str(_ROOT / 'README.rst')) as f:
    readme = f.read()


# Extract the version from the annotated ``__version__`` assignment in
# asyncpg/_version.py: take everything after the first ``=`` and strip
# surrounding blanks, newlines and quotes.  The ``for``/``else`` raises
# when no such line is present.
with open(str(_ROOT / 'asyncpg' / '_version.py')) as f:
    for line in f:
        if line.startswith('__version__: typing.Final ='):
            _, _, version = line.partition('=')
            VERSION = version.strip(" \n'\"")
            break
    else:
        raise RuntimeError(
            'unable to read the version from asyncpg/_version.py')


if (_ROOT / '.git').is_dir() and 'dev' in VERSION: # This is a git checkout, use git to # generate a precise version.
def git_commitish(): env = {} v = os.environ.get('PATH') if v is not None: env['PATH'] = v git = subprocess.run(['git', 'rev-parse', 'HEAD'], env=env, cwd=str(_ROOT), stdout=subprocess.PIPE) if git.returncode == 0: commitish = git.stdout.strip().decode('ascii') else: commitish = 'unknown' return commitish VERSION += '+' + git_commitish()[:7] class VersionMixin: def _fix_version(self, filename): # Replace asyncpg.__version__ with the actual version # of the distribution (possibly inferred from git). with open(str(filename)) as f: content = f.read() version_re = r"(.*__version__\s*=\s*)'[^']+'(.*)" repl = r"\1'{}'\2".format(self.distribution.metadata.version) content = re.sub(version_re, repl, content) with open(str(filename), 'w') as f: f.write(content) class sdist(setuptools_sdist.sdist, VersionMixin): def make_release_tree(self, base_dir, files): super().make_release_tree(base_dir, files) self._fix_version(pathlib.Path(base_dir) / 'asyncpg' / '_version.py') class build_py(setuptools_build_py.build_py, VersionMixin): def build_module(self, module, module_file, package): outfile, copied = super().build_module(module, module_file, package) if module == '__init__' and package == 'asyncpg': self._fix_version(outfile) return outfile, copied class build_ext(setuptools_build_ext.build_ext): user_options = setuptools_build_ext.build_ext.user_options + [ ('cython-always', None, 'run cythonize() even if .c files are present'), ('cython-annotate', None, 'Produce a colorized HTML version of the Cython source.'), ('cython-directives=', None, 'Cython compiler directives'), ] def initialize_options(self): # initialize_options() may be called multiple times on the # same command object, so make sure not to override previously # set options. 
if getattr(self, '_initialized', False): return super(build_ext, self).initialize_options() defines = [ "CYTHON_USE_MODULE_STATE", "CYTHON_PEP489_MULTI_PHASE_INIT", "CYTHON_USE_TYPE_SPECS", ] if os.environ.get('ASYNCPG_DEBUG'): self.cython_always = True self.cython_annotate = True self.cython_directives = "linetrace=True" self.debug = True defines += ["PG_DEBUG", "CYTHON_TRACE", "CYTHON_TRACE_NOGIL"] else: self.cython_always = False self.cython_annotate = None self.cython_directives = None self.define = ",".join(defines) def finalize_options(self): # finalize_options() may be called multiple times on the # same command object, so make sure not to override previously # set options. if getattr(self, '_initialized', False): return if not self.cython_always: self.cython_always = bool(os.environ.get( "ASYNCPG_BUILD_CYTHON_ALWAYS")) if self.cython_annotate is None: self.cython_annotate = os.environ.get( "ASYNCPG_BUILD_CYTHON_ANNOTATE") if self.cython_directives is None: self.cython_directives = os.environ.get( "ASYNCPG_BUILD_CYTHON_DIRECTIVES") need_cythonize = self.cython_always cfiles = {} for extension in self.distribution.ext_modules: for i, sfile in enumerate(extension.sources): if sfile.endswith('.pyx'): prefix, ext = os.path.splitext(sfile) cfile = prefix + '.c' if os.path.exists(cfile) and not self.cython_always: extension.sources[i] = cfile else: if os.path.exists(cfile): cfiles[cfile] = os.path.getmtime(cfile) else: cfiles[cfile] = 0 need_cythonize = True if need_cythonize: import pkg_resources # Double check Cython presence in case setup_requires # didn't go into effect (most likely because someone # imported Cython before setup_requires injected the # correct egg into sys.path. 
try: import Cython except ImportError: raise RuntimeError( 'please install {} to compile asyncpg from source'.format( CYTHON_DEPENDENCY)) cython_dep = pkg_resources.Requirement.parse(CYTHON_DEPENDENCY) if Cython.__version__ not in cython_dep: raise RuntimeError( 'asyncpg requires {}, got Cython=={}'.format( CYTHON_DEPENDENCY, Cython.__version__ )) from Cython.Build import cythonize directives = { 'language_level': '3', 'freethreading_compatible': 'True', 'subinterpreters_compatible': 'own_gil', } if self.cython_directives: for directive in self.cython_directives.split(','): k, _, v = directive.partition('=') if v.lower() == 'false': v = False if v.lower() == 'true': v = True directives[k] = v self.distribution.ext_modules[:] = cythonize( self.distribution.ext_modules, compiler_directives=directives, annotate=self.cython_annotate) super(build_ext, self).finalize_options() setup_requires = [] if ( not (_ROOT / 'asyncpg' / 'protocol' / 'protocol.c').exists() or os.environ.get("ASYNCPG_BUILD_CYTHON_ALWAYS") ): # No Cython output, require Cython to build. 
setup_requires.append(CYTHON_DEPENDENCY) _ = setuptools.setup( version=VERSION, ext_modules=[ setuptools.extension.Extension( "asyncpg.pgproto.pgproto", ["asyncpg/pgproto/pgproto.pyx"], extra_compile_args=CFLAGS, extra_link_args=LDFLAGS), setuptools.extension.Extension( "asyncpg.protocol.record", ["asyncpg/protocol/record/recordobj.c"], include_dirs=['asyncpg/protocol/record/'], extra_compile_args=CFLAGS, extra_link_args=LDFLAGS), setuptools.extension.Extension( "asyncpg.protocol.protocol", ["asyncpg/protocol/protocol.pyx"], include_dirs=['asyncpg/pgproto/'], extra_compile_args=CFLAGS, extra_link_args=LDFLAGS), ], cmdclass={'build_ext': build_ext, 'build_py': build_py, 'sdist': sdist}, setup_requires=setup_requires, ) ================================================ FILE: tests/__init__.py ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 import pathlib import sys import unittest def suite(): test_loader = unittest.TestLoader() test_suite = test_loader.discover(str(pathlib.Path(__file__).parent), pattern='test_*.py') return test_suite if __name__ == '__main__': runner = unittest.runner.TextTestRunner(verbosity=2) result = runner.run(suite()) sys.exit(not result.wasSuccessful()) ================================================ FILE: tests/certs/ca.cert.pem ================================================ -----BEGIN CERTIFICATE----- MIIGJjCCBA6gAwIBAgIICJCUmtkcj2MwDQYJKoZIhvcNAQELBQAwgaExCzAJBgNV BAYTAkNBMRAwDgYDVQQIDAdPbnRhcmlvMRAwDgYDVQQHDAdUb3JvbnRvMRgwFgYD VQQKDA9NYWdpY1N0YWNrIEluYy4xFjAUBgNVBAsMDWFzeW5jcGcgdGVzdHMxHTAb BgNVBAMMFGFzeW5jcGcgdGVzdCByb290IGNhMR0wGwYJKoZIhvcNAQkBFg5oZWxs b0BtYWdpYy5pbzAeFw0yNDEwMTYxNzIzNTZaFw00MzEyMTcxNzIzNTZaMIGhMQsw CQYDVQQGEwJDQTEQMA4GA1UECAwHT250YXJpbzEQMA4GA1UEBwwHVG9yb250bzEY 
MBYGA1UECgwPTWFnaWNTdGFjayBJbmMuMRYwFAYDVQQLDA1hc3luY3BnIHRlc3Rz MR0wGwYDVQQDDBRhc3luY3BnIHRlc3Qgcm9vdCBjYTEdMBsGCSqGSIb3DQEJARYO aGVsbG9AbWFnaWMuaW8wggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQCP +oCl0qrReSlWj+yvfGz68UQqm6joL9VgeA0Tvc8S23Ia3S73wcTTdGhIQwMOaIuW y+m3J3js2wtpF0fmULYHr1ED7vQ+QOWarTyv/cGxSCyOYo4KVPHBfT6lYQTJk5NW Oc2wr5ff/9nhdO61sGxZa2GVBjmbLOJ9IBKTvRcmNgLmPo60wMHtF4L5/PuwVPuu +zRoETfEh12avtY7Y2G+0i4ZRm4uBmw7hmByWzWCwqrV619BaFHaJUf2bEh5eCbz 1nhF7WHVjBfnSJOgDxmZbKZZPmNzTVm8UxN22g9Ao6cZSxjbFAdpIhlQhAT6sjlW hvI6b58A3AJKi7zo+a7lnbPIeckduSkgbil3LZ4KxWgx6fPCBLqGH1XN6I8MQnX/ e1ewiFXwuZMb+FgoKxaQBseuPVaA3ViYefysjvLjP7U9eRzv6qRimOmH5efaplbD zGhRUKA8GgmN/B+S3ofqDhpp3zz7gFxjkE1f4/XNACqXt79iGaH+EscV4znxlsZj gUQYAcExpAmKrJg5kmxagHcgu0pVKlyUvSba/kKQ/aYDgdddgPutH+UHs5pssc69 YBpEXQTG9CMeRh6ZUgcrR0foJLM5g2k53xpG1oTHiJcCKARFZPRpDoZ6NjCIuFKY 6+HMcpFRVDsDnUXmFah9bUhsSQbc6MHHX/iTbpMGNwIDAQABo2AwXjAPBgNVHRMB Af8EBTADAQH/MAsGA1UdDwQEAwIBBjAdBgNVHQ4EFgQUhGQbAW97KXQs68Z3efEj 55zsc4UwHwYDVR0jBBgwFoAUhGQbAW97KXQs68Z3efEj55zsc4UwDQYJKoZIhvcN AQELBQADggIBADsy7jhBmwGbOZPox0XvB2XzWjOPl3uI3Ys3uGaAXVbGVnP3nDtU waGg7Fhf/ibQVAOkWLfm9FCJEO6bEojF4CjCa//iMqXgnPJaWeYceb8+CzuF5Ukg n/kfbj04dVvOnPa8KYkMOWQ6zsBgKuNaA5jOKWYwoHFgQNjKRiVikyOp6zF3aPu0 wW7M7FOVHn0ZhMRBcJG8dGbQ8vaeu8z4i04tlvpQaFgtY66ECeUwhTIrvVuqtQOl jR//w70TUTIH3JzzYmyCubOCjdqcNRYPRRiA/L+mdzrE7honSTQfo0iupT/5bJcu GRjLHL/aRvYrq8ogqQKIYW0EbVuFzHfb+kPV61Bf5APbA26GU/14XkA4KwzJnDMR d2wr0RivSceXtY2ZakYP6+2cqjuhk6Y0tl0FBuyQXqAbe1L7X2VctLJMi5UgksVB q5rdHSJ3fbHRoCUpj4/rSafqJNHlAf2MEE/q8l0D8JhYoN69RhvyFQJLFEU4c74b XHdFt6bfyxm4+ZzUdj/TXadPAUO1YfQCn9Tf7QOoR68acSvQxEDbChZlJYkdAE+C zxNcoHVc6XIpk7NIr09qTQ5viz736fV6EI6OIoUaqrz9u+NZ3sPPD2Gf+rOinVFQ R2Q5kxHYo8Kt1DK0fFcUe1cOZk3df7seQWw1OdJngp5S7gEWBiWg8zr7 -----END CERTIFICATE----- ================================================ FILE: tests/certs/ca.crl.pem ================================================ -----BEGIN X509 CRL----- MIIDAjCB6wIBATANBgkqhkiG9w0BAQsFADCBoTELMAkGA1UEBhMCQ0ExEDAOBgNV 
BAgMB09udGFyaW8xEDAOBgNVBAcMB1Rvcm9udG8xGDAWBgNVBAoMD01hZ2ljU3Rh Y2sgSW5jLjEWMBQGA1UECwwNYXN5bmNwZyB0ZXN0czEdMBsGA1UEAwwUYXN5bmNw ZyB0ZXN0IHJvb3QgY2ExHTAbBgkqhkiG9w0BCQEWDmhlbGxvQG1hZ2ljLmlvFw0y MTA5MTQxNjA2MDFaFw0yMTA5MTUxNjA2MDFaMBUwEwICEAAXDTIxMDkxNDE2MDYw MVowDQYJKoZIhvcNAQELBQADggIBAL4yfNmvGS8SkIVbRzdAC9+XJPw/dBJOUJwr EgERICAz7OTqG1PkmMhPL00Dm9fe52+KnSwHgL749W0S/X5rTNMSwLyGiiJ5HYbH GFRKQ/cvXLi4jYpSI1Ac94kk0japf3SfwEw3+122oba8SiAVP0nY3bHpHvNfOaDV fhbFTwb5bFm6ThqlKLZxGCKP0fGeQ4homuwgRiLE/UOiue5ted1ph0PkKVui208k FnhNYXSllakTGT8ZZZZVid/4tSHqJEY9vbdMXNv1GX8mhjoU1Gv9dOuyFGgUc9Vx e7gzf/Wf36vKI29o8QGkkTslRZpMG59z3sG4Y0vJEoqXMB6eQLOr5iUCyj2CyDha 66pwrdc1fRt3EvNXUWkdHfY3EHb7DxueedDEgtmfSNbEaZTXa5RaZRavNGNTaPDf UcrDU4w1N0wkYLQxPqd+VPcf1iKyfkAydpeOq9CChqRD0Tx58eTn6N/lLGFPPRfs x47BA4FmefBeXZzd5HiXCUouk3qHIHs2yCzFs+TEBkx5eV42cP++HxjirPydLf6Y G/o/TKRnc/2Lw+dCzvUV/p3geuw4+vq1BIFanwB9jp4tGaBrffIAyle8vPQLw6bp 1o1O39pdxniz+c9r0Kw/ETxTqRLbasSib5FHq5G/G9a+QxPsLAzKgwLWhR4fXvbu YPbhYhRP -----END X509 CRL----- ================================================ FILE: tests/certs/ca.key.pem ================================================ -----BEGIN RSA PRIVATE KEY----- MIIJKAIBAAKCAgEAj/qApdKq0XkpVo/sr3xs+vFEKpuo6C/VYHgNE73PEttyGt0u 98HE03RoSEMDDmiLlsvptyd47NsLaRdH5lC2B69RA+70PkDlmq08r/3BsUgsjmKO ClTxwX0+pWEEyZOTVjnNsK+X3//Z4XTutbBsWWthlQY5myzifSASk70XJjYC5j6O tMDB7ReC+fz7sFT7rvs0aBE3xIddmr7WO2NhvtIuGUZuLgZsO4Zgcls1gsKq1etf QWhR2iVH9mxIeXgm89Z4Re1h1YwX50iToA8ZmWymWT5jc01ZvFMTdtoPQKOnGUsY 2xQHaSIZUIQE+rI5VobyOm+fANwCSou86Pmu5Z2zyHnJHbkpIG4pdy2eCsVoMenz wgS6hh9VzeiPDEJ1/3tXsIhV8LmTG/hYKCsWkAbHrj1WgN1YmHn8rI7y4z+1PXkc 7+qkYpjph+Xn2qZWw8xoUVCgPBoJjfwfkt6H6g4aad88+4BcY5BNX+P1zQAql7e/ Yhmh/hLHFeM58ZbGY4FEGAHBMaQJiqyYOZJsWoB3ILtKVSpclL0m2v5CkP2mA4HX XYD7rR/lB7OabLHOvWAaRF0ExvQjHkYemVIHK0dH6CSzOYNpOd8aRtaEx4iXAigE RWT0aQ6GejYwiLhSmOvhzHKRUVQ7A51F5hWofW1IbEkG3OjBx1/4k26TBjcCAwEA AQKCAgABseW8zf+TyrTZX4VeRX008Q0n4UA6R4HgClnBDz12T94Gge8RHJdYE+k8 
XImXLFTkWA8uyEispSF7wbnndLDH42D1RmVarEHnsb1ipv6WOy7HGFLqvThBWluX 783yH4oe/Dw3JcIIcYcbl9hNjD+iR9jUu8eG057w8SU21wWEPiOHmVntt80woNO6 ZKeD2mRCGZPy260H474O2ctE1LUsXWYMhx857HpusvTEs90r5mXDcetjpjo8cq7n sDukLm1q9m3hCNvbezQ21UxjmHnpK/XDXDAohdMWG/ZBMmz2ilanvhITVieGLdAV ehBi8SEqqxkD5hd9l5lxTjbRmUrdRZilnUKqup9WcOTQYeAZ2WAazyYuFqWAwSf+ dU+SzMTG+7ts9y4RbnWL9H6hN2GWMeNdLRVqE4aECMv7kAIJZ2u6VyNXSEoVueBM CJ7CU075QgxNL1REDWRBaUaflBhdwQFnMXBULw2E01KZFmQvZLe06SI/xjkB7oGU HdqWRDx0YP8lrFG35ukA2t+EswJxcbZHsagEdrz0jjz0a87vjgHnff1XpowhZU6M 4OgtQpoM4t4O7xg/sl80c0WwVvsOHVkGwUARCfZ4F2fXnocpYOCWQQbsA/SH/qJ8 l+ChM4XkBNzKAUtpwkozqisKURJKTAJyeuAKD4fXRX/IwcPUYQKCAQEAyp1iiuTX pXzDso+3WPxLr3kwYJSUxpxSP4EjZZvzJoVflFBttUOoLURPEMrK5tEqWHqRrJto 73s3yQt4xWUtUql5eCB69nIVjseRhsbXjNzMIC41u65aflfIqQztHzF2gdFMZh3I gBp87CzKHSf83ToN3QZtQxIvuPdYdxDIjCMHc5hgRSLNKGhKXs1qWA76ASGNwQKW 7nUflWfDG3yZ7sWtmz7T2djz2zsmmzppCRRVjHAxQWZ+TxW+KsBOpGzgNvteUese ZK2ARc6lLSdgS74J5U6j07dOzQZ4eVC/OPHAIbPZxJAZ7/waP7YM+h+ohU+G8kXL KevnXjsC2oa/FwKCAQEAteoHugnwXvl9VyPceGQeffmQIq095CoD35UVlq60yR/9 zgGN8mrXuEgGyydCYrK0/pUYb1pQhk5Xy1D6t5ou44uYlGuksWDqquRwgl7qMMVE 0GAwm+3wUmz7u5XD3uEJaGWV+gbvg8Hbvl3V/MzjlI4caAZ3lcNaX/Jf3xG6Gyfi So0iQzVMN6NR7m+I32YFB3jxu9PlzUTEj+9SCHuERFAozuzwjdLwiYjNMzv0zPWj v3ERO2mX6PE6yN1XkBsCGGG9qVz/ZzvKOz8Dl4TryY0a5eg4QUEZ3nUlnpq9/8M3 xcN6M2yK8XLbTmVhSHX2J5nVI3s+BTbVHBoO0edl4QKCAQBcmMbTUThYkgdh0Jpr WYpBXHJGgUDo78IK8bq6kiXygdunjYZF4/C1F1XHB9bo28itfP6cUr4HTFm3UL3W AKJQ99DinH11qbe+c+hHHxKddr73Kgc2ib0jpny2/YhUzCcrtvpiZNQf73sN+H46 Cu9eL0zsqSZAE8ypjKjqaUot+UhLhOTiU8BM6jSq1Nf3/Ig3Ah2lishtnCtd/XjG VBCJdeAcZf8tvR/dHlBLestL8fYS46cvC2dIP1iUcyS9smBZ4FE/wOM4Aa7wuDr2 wtsYYnZlTKZEeK7TtlRSpRtvK9Sx0l8AnRatfZqFaW7O1K8QlcLHcCwkMYKgpvlr 407rAoIBAQCi5nqa1xGgCux53wwr5wQDLTssQlS8//7N9ZQKhlIwFOzT0EKLha+9 PwqOW46wEXXQ0DS8anTXgEpQMCkDxxcb/sLYjfhCOxaJh91Ucahnmg+ARdLhn1Xo id124qsu5/fju6xs5E8RfsTHmQHpypQ1UHkRklD+FJzWdJXzjM1KShHzTqUS6CRj YmYZDVnVK2dvhJd76knL4jve5KFiJTGRdvLEMhtL9Uwe7RlMOvGBpKpI4fhbarh1 
CafpfYRO8FCVAtmzUysHB9yV51zRD1+R8kDXBndxv9lpgx/4AnwID4nfF6hTamyV wJOwhUpzd+bBGZlql483Xh3Cd3cz8nIhAoIBACs/XIDpXojtWopHXZReNwhqPC1D q3rjpPrZ8uqDu0Z/iTTO9OSvYaMBTVjXQ7w8T3X3ilMr45kpsHx0TQeh3Jbjy459 S9z+6MtSIM0fbpYBEfa7sirDQM/ZlgZjm7vq/4lBVFGFIw7vxu4m/G0oHtihWRKh ClGG1Ypm00srgWihhjtRn8hfnLqCi4t9xxW1q8Te01Gem8H0nfNKfs5V8O4cKIZa izrfne/1Fto1khYFTlP6XdVHPjvl2/qX2WUz4G+2eNWGQVghC70cuV8kiFYlEXVp a6w2oSx8jo+5qRZrMlUQP5bE7dOBvZuoBmEi/FVfRYuFdxSZ3H2VAZKRgC4= -----END RSA PRIVATE KEY----- ================================================ FILE: tests/certs/client.cert.pem ================================================ -----BEGIN CERTIFICATE----- MIIEAzCCAuugAwIBAgIUPfej8IQ/5bCrihqWImrq2vKPOq0wDQYJKoZIhvcNAQEL BQAwgaMxCzAJBgNVBAYTAkNBMRAwDgYDVQQIDAdPbnRhcmlvMRAwDgYDVQQHDAdU b3JvbnRvMRgwFgYDVQQKDA9NYWdpY1N0YWNrIEluYy4xFjAUBgNVBAsMDWFzeW5j cGcgdGVzdHMxHzAdBgNVBAMMFmFzeW5jcGcgdGVzdCBjbGllbnQgQ0ExHTAbBgkq hkiG9w0BCQEWDmhlbGxvQG1hZ2ljLmlvMB4XDTIxMDgwOTIxNTA1MloXDTMyMDEw NDIxNTA1MlowgZUxCzAJBgNVBAYTAkNBMRAwDgYDVQQIDAdPbnRhcmlvMRAwDgYD VQQHDAdUb3JvbnRvMRgwFgYDVQQKDA9NYWdpY1N0YWNrIEluYy4xFjAUBgNVBAsM DWFzeW5jcGcgdGVzdHMxETAPBgNVBAMMCHNzbF91c2VyMR0wGwYJKoZIhvcNAQkB Fg5oZWxsb0BtYWdpYy5pbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB AJjiP9Ik/KRRLK9GMvoH8m1LO+Gyrr8Gz36LpmKJMR/PpwTL+1pOkYSGhOyT3Cw9 /kWWLJRCvYqKgFtYtbr4S6ReGm3GdSVW+sfVRYDrRQZLPgQSPeq25g2v8UZ63Ota lPAyUPUZKpxyWz8PL77lV8psb9yv14yBH2kv9BbxKPksWOU8p8OCn1Z3WFFl0ItO nzMvCp5os+xFrt4SpoRGTx9x4QleY+zrEsYZtmnV4wC+JuJkNw4fuCdrX5k7dghs uZkcsAZof1nMdYsYiazeDfQKZtJqh5kO7mpwvCudKUWaLJJUwiQA87BwSlnCd/Hh TZDbC+zeFNjTS49/4Q72xVECAwEAAaM7MDkwHwYDVR0jBBgwFoAUi1jMmAisuOib mHIE2n0W2WnnaL0wCQYDVR0TBAIwADALBgNVHQ8EBAMCBPAwDQYJKoZIhvcNAQEL BQADggEBACbnp5oOp639ko4jn8axF+so91k0vIcgwDg+NqgtSRsuAENGumHAa8ec YOks0TCTvNN5E6AfNSxRat5CyguIlJ/Vy3KbkkFNXcCIcI/duAJvNphg7JeqYlQM VIJhrO/5oNQMzzTw8XzTHnciGbrbiZ04hjwrruEkvmIAwgQPhIgq4H6umTZauTvk DEo7uLm7RuG9hnDyWCdJxLLljefNL/EAuDYpPzgTeEN6JAnOu0ULIbpxpJKiYEId 
8I0U2n0I2NTDOHmsAJiXf8BiHHmpK5SXFyY9s2ZuGkCzvmeZlR81tTXmHZ3v1X2z 8NajoAZfJ+QD50DrbF5E00yovZbyIB4= -----END CERTIFICATE----- ================================================ FILE: tests/certs/client.csr.pem ================================================ -----BEGIN CERTIFICATE REQUEST----- MIIC2zCCAcMCAQAwgZUxCzAJBgNVBAYTAkNBMRAwDgYDVQQIDAdPbnRhcmlvMRAw DgYDVQQHDAdUb3JvbnRvMRgwFgYDVQQKDA9NYWdpY1N0YWNrIEluYy4xFjAUBgNV BAsMDWFzeW5jcGcgdGVzdHMxETAPBgNVBAMMCHNzbF91c2VyMR0wGwYJKoZIhvcN AQkBFg5oZWxsb0BtYWdpYy5pbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC ggEBAJjiP9Ik/KRRLK9GMvoH8m1LO+Gyrr8Gz36LpmKJMR/PpwTL+1pOkYSGhOyT 3Cw9/kWWLJRCvYqKgFtYtbr4S6ReGm3GdSVW+sfVRYDrRQZLPgQSPeq25g2v8UZ6 3OtalPAyUPUZKpxyWz8PL77lV8psb9yv14yBH2kv9BbxKPksWOU8p8OCn1Z3WFFl 0ItOnzMvCp5os+xFrt4SpoRGTx9x4QleY+zrEsYZtmnV4wC+JuJkNw4fuCdrX5k7 dghsuZkcsAZof1nMdYsYiazeDfQKZtJqh5kO7mpwvCudKUWaLJJUwiQA87BwSlnC d/HhTZDbC+zeFNjTS49/4Q72xVECAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4IBAQCG irI2ph09V/4BMe6QMhjBFUatwmTa/05PYGjvT3LAhRzEb3/o/gca0XFSAFrE6zIY DsgMk1c8aLr9DQsn9cf22oMFImKdnIZ3WLE9MXjN+s1Bjkiqt7uxDpxPo/DdfUTQ RQC5i/Z2tn29y9K09lEjp35ZhPp3tOA0V4CH0FThAjRR+amwaBjxQ7TTSNfoMUd7 i/DrylwnNg1iEQmYUwJYopqgxtwseiBUSDXzEvjFPY4AvZKmEQmE5QkybpWIfivt 1kmKhvKKpn5Cb6c0D3XoYqyPN3TxqjH9L8R+tWUCwhYJeDZj5DumFr3Hw/sx8tOL EctyS6XfO3S2KbmDiyv8 -----END CERTIFICATE REQUEST----- ================================================ FILE: tests/certs/client.key.pem ================================================ -----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEAmOI/0iT8pFEsr0Yy+gfybUs74bKuvwbPfoumYokxH8+nBMv7 Wk6RhIaE7JPcLD3+RZYslEK9ioqAW1i1uvhLpF4abcZ1JVb6x9VFgOtFBks+BBI9 6rbmDa/xRnrc61qU8DJQ9RkqnHJbPw8vvuVXymxv3K/XjIEfaS/0FvEo+SxY5Tyn w4KfVndYUWXQi06fMy8Knmiz7EWu3hKmhEZPH3HhCV5j7OsSxhm2adXjAL4m4mQ3 Dh+4J2tfmTt2CGy5mRywBmh/Wcx1ixiJrN4N9Apm0mqHmQ7uanC8K50pRZosklTC JADzsHBKWcJ38eFNkNsL7N4U2NNLj3/hDvbFUQIDAQABAoIBAAIMVeqM0E2rQLwA ZsJuxNKuBVlauXiZsMHzQQFk8SGJ+KTZzr5A+zYZT0KUIIj/M57fCi3aTwvCG0Ie 
CCE/HlRPZm8+D2e2qJlwxAOcI0qYS3ZmgCna1W4tgz/8eWU1y3UEV41RDv8VkR9h JrSaAfkWRtFgEbUyLaeNGuoLxQ7Bggo9zi1/xDJz/aZ/y4L4y8l1xs2eNVmbRGnj mPr1daeYhsWgaNiT/Wm3CAxvykptHavyWSsrXzCp0bEw6fAXxBqkeDFGIMVC9q3t ZRFtqMHi9i7SJtH1XauOC6QxLYgSEmNEie1JYbNx2Zf4h2KvSwDxpTqWhOjJ/m5j /NSkASECgYEAyHQAqG90yz5QaYnC9lgUhGIMokg9O3LcEbeK7IKIPtC9xINOrnj6 ecCfhfc1aP3wQI+VKC3kiYerfTJvVsU5CEawBQSRiBY/TZZ7hTR7Rkm3s4xeM+o6 2zADdVUwmTVYwu0gUKCeDKO4iD8Uhh8J54JrKUejuG50VWZQWGVgqo0CgYEAwz+2 VdYcfuQykMA3jQBnXmMMK92/Toq6FPDgsa45guEFD6Zfdi9347/0Ipt+cTNg0sUZ YBLOnNPwLn+yInfFa88Myf0UxCAOoZKfpJg/J27soUJzpd/CGx+vaAHrxMP6t/qo JAGMBIyOoqquId7jvErlC/sGBk/duya7IdiT1tUCgYBuvM8EPhaKlVE9DJL9Hmmv PK94E2poZiq3SutffzkfYpgDcPrNnh3ZlxVJn+kMqITKVcfz226On7mYP32MtQWt 0cc57m0rfgbYqRJx4y1bBiyK7ze3fGWpYxv1/OsNKJBxlygsAp9toiC2fAqtkYYa NE1ZD6+dmr9/0jb+rnq5nQKBgQCtZvwsp4ePOmOeItgzJdSoAxdgLgQlYRd6WaN0 qeLx1Z6FE6FceTPk1SmhQq+9IYAwMFQk+w78QU3iPg6ahfyTjsMw8M9sj3vvCyU1 LPGJt/34CehjvKHLLQy/NlWJ3vPgSYDi2Wzc7WgQF72m3ykqpOlfBoWHPY8TE4bG vG4wMQKBgFSq2GDAJ1ovBl7yWYW7w4SM8X96YPOff+OmI4G/8+U7u3dDM1dYeQxD 7BHLuvr4AXg27LC97u8/eFIBXC1elbco/nAKE1YHj2xcIb/4TsgAqkcysGV08ngi dULh3q0GpTYyuELZV4bfWE8MjSiGAH+nuMdXYDGuY2QnBq8MdSOH -----END RSA PRIVATE KEY----- ================================================ FILE: tests/certs/client.key.protected.pem ================================================ -----BEGIN RSA PRIVATE KEY----- Proc-Type: 4,ENCRYPTED DEK-Info: AES-256-CBC,B222CD7D00828606A07DBC489D400921 LRHsNGUsD5bG9+x/1UlzImN0rqEF10sFPBmxKeQpXQ/hy4iR+X/Gagoyagi23wOn EZf0sCLJx95ixG+4fXJDX0jgBtqeziVNS4FLWHIuf3+blja8nf4tkmmH9pF8jFQ0 i1an3TP6KRyDKa17gioOdtSsS51BZmPkp3MByJQsrMhyB0txEUsGtUMaBTYmVN/5 uYHf9MsmfcfQy30nt2t6St6W82QupHHMOx5xyhPJo8cqQncZC7Dwo4hyDV3h3vWn UjaRZiEMmQ3IgCwfJd1VmMECvrwXd/sTOXNhofWwDQIqmQ3GGWdrRnmgD863BQT3 V8RVyPLkutOnrZ/kiMSAuiXGsSYK0TV8F9TaP/abLob4P8jbKYLcuR7ws3cu1xBl XWt9RALxGPUyHIy+BWLXJTYL8T+TVJpiKsAGCQB54j8VQBSArwFL4LnzdUu1txe2 qa6ZEwt4q6SEwOTJpJWz3oJ1j+OTsRCN+4dlyo7sEZMeyTRp9nUzwulhd+fOdIhY 
2UllMG71opKfNxZzEW7lq6E/waf0MmxwjUJmgwVO218yag9oknHnoFwewF42DGY7 072h23EJeKla7sI+MAB18z01z6C/yHWXLybOlXaGqk6zOm3OvTUFnUXtKzlBO2v3 FQwrOE5U/VEyQkNWzHzh4j4LxYEL9/B08PxaveUwvNVGn9I3YknE6uMfcU7VuxDq +6bgM6r+ez+9QLFSjH/gQuPs2DKX0h3b9ppQNx+MANX0DEGbGabJiBp887f8pG6Q tW0i0+rfzYz3JwnwIuMZjYz6qUlP4bJMEmmDfod3fbnvg3MoCSMTUvi1Tq3Iiv4L GM5/YNkL0V3PhOI686aBfU7GLGXQFhdbQ9xrSoQRBmmNBqTCSf+iIEoTxlBac8GQ vSzDO+A+ovBP36K13Yn7gzuN/3PLZXH2TZ8t2b/OkEXOciH5KbycGHQA7gqxX1P4 J55gpqPAWe8e7wKheWj3BMfmbWuH4rpiEkrLpqbTSfTwIKqplk253chmJj5I82XI ioFLS5vCi9JJsTrQ720O+VQPVB5xeA80WL8NxamWQb/KkvVnb4dTmaV30RCgLLZC tuMx8YSW71ALLT15qFB2zlMDKZO1jjunNE71BUFBPIkTKEOCyMAiF60fFeIWezxy kvBBOg7+MTcZNeW110FqRWNGr2A5KYFN15g+YVpfEoF26slHisSjVW5ndzGh0kaQ sIOjQitA9JYoLua7sHvsr6H5KdCGjNxv7O7y8wLGBVApRhU0wxZtbClqqEUvCLLP UiLDp9L34wDL7sGrfNgWA4UuN29XQzTxI5kbv/EPKhyt2oVHLqUiE+eGyvnuYm+X KqFi016nQaxTU5Kr8Pl0pSHbJMLFDWLSpsbbTB6YJpdEGxJoj3JB3VncOpwcuK+G xZ1tV2orPt1s/6m+/ihzRgoEkyLwcLRPN7ojgD/sqS679ZGf1IkDMgFCQe4g0UWm Fw7v816MNCgypUM5hQaU+Jp8vSlEc29RbrdSHbcxrKj/xPCLWrAbvmI5tgonKmuJ J1LW8AXyh/EUp/uUh++jqVGx+8pFfcmJw6V6JrJzQ7HMlakkry7N1eAGrIJGtYCW -----END RSA PRIVATE KEY----- ================================================ FILE: tests/certs/client_ca.cert.pem ================================================ -----BEGIN CERTIFICATE----- MIIEKTCCAxGgAwIBAgIUKmL8tfNS9LIB6GLB9RpZpTyk3uIwDQYJKoZIhvcNAQEL BQAwgaMxCzAJBgNVBAYTAkNBMRAwDgYDVQQIDAdPbnRhcmlvMRAwDgYDVQQHDAdU b3JvbnRvMRgwFgYDVQQKDA9NYWdpY1N0YWNrIEluYy4xFjAUBgNVBAsMDWFzeW5j cGcgdGVzdHMxHzAdBgNVBAMMFmFzeW5jcGcgdGVzdCBjbGllbnQgQ0ExHTAbBgkq hkiG9w0BCQEWDmhlbGxvQG1hZ2ljLmlvMB4XDTIxMDgwOTIxNDQxM1oXDTQxMDgw NDIxNDQxM1owgaMxCzAJBgNVBAYTAkNBMRAwDgYDVQQIDAdPbnRhcmlvMRAwDgYD VQQHDAdUb3JvbnRvMRgwFgYDVQQKDA9NYWdpY1N0YWNrIEluYy4xFjAUBgNVBAsM DWFzeW5jcGcgdGVzdHMxHzAdBgNVBAMMFmFzeW5jcGcgdGVzdCBjbGllbnQgQ0Ex HTAbBgkqhkiG9w0BCQEWDmhlbGxvQG1hZ2ljLmlvMIIBIjANBgkqhkiG9w0BAQEF AAOCAQ8AMIIBCgKCAQEAptRYfxKiWExfZguQDva53bIqYa4lJwZA86Qu0peBUcsd 
E6zyHNgVv4XSMim1FH12KQ4KPKuQAcVqRMCRAHqB96kUfWQqF//fLajr0umdzcbx +UTgNux8TkScTl9KNAxhiR/oOGbKFcNSs4raaG8puwwEN66uMhoKk2pN2NwDVfHa bTekJ3jouTcTCnqCynx4qwI4WStJkuW4IPCmDRVXxOOauT7YalElYLWYtAOqGEvf noDK2Imhc0h6B5XW8nI54rVCXWwhW1v3RLAJGP+LwSy++bf08xmpHXdKkAj5BmUO QwJRiJ33Xa17rmi385egx8KpqV04YEAPdV1Z4QM6PQIDAQABo1MwUTAdBgNVHQ4E FgQUi1jMmAisuOibmHIE2n0W2WnnaL0wHwYDVR0jBBgwFoAUi1jMmAisuOibmHIE 2n0W2WnnaL0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAifNE ZLZXxECp2Sl6jCViZxgFf2+OHDvRORgI6J0heckYyYF/JHvLaDphh6TkSJAdT6Y3 hAb7jueTMI+6RIdRzIjTKCGdJqUetiSfAbnQyIp2qmVqdjeFoXTvQL7BdkIE+kOW 0iomMqDB3czTl//LrgVQCYqKM0D/Ytecpg2mbshLfpPxdHyliCJcb4SqfdrDnKoV HUduBjOVot+6bkB5SEGCrrB4KMFTzbAu+zriKWWz+uycIyeVMLEyhDs59vptOK6e gWkraG43LZY3cHPiVeN3tA/dWdyJf9rgK21zQDSMB8OSH4yQjdQmkkvRQBjp3Fcy w2SZIP4o9l1Y7+hMMw== -----END CERTIFICATE----- ================================================ FILE: tests/certs/client_ca.cert.srl ================================================ 3DF7A3F0843FE5B0AB8A1A96226AEADAF28F3AAD ================================================ FILE: tests/certs/client_ca.key.pem ================================================ -----BEGIN RSA PRIVATE KEY----- MIIEpAIBAAKCAQEAptRYfxKiWExfZguQDva53bIqYa4lJwZA86Qu0peBUcsdE6zy HNgVv4XSMim1FH12KQ4KPKuQAcVqRMCRAHqB96kUfWQqF//fLajr0umdzcbx+UTg Nux8TkScTl9KNAxhiR/oOGbKFcNSs4raaG8puwwEN66uMhoKk2pN2NwDVfHabTek J3jouTcTCnqCynx4qwI4WStJkuW4IPCmDRVXxOOauT7YalElYLWYtAOqGEvfnoDK 2Imhc0h6B5XW8nI54rVCXWwhW1v3RLAJGP+LwSy++bf08xmpHXdKkAj5BmUOQwJR iJ33Xa17rmi385egx8KpqV04YEAPdV1Z4QM6PQIDAQABAoIBABQrKcO7CftoyEO6 9CCK/W9q4arLddxg6itKVwrInC66QnqlduO7z+1GjWHZHvYqMMXH17778r30EuPa 7+zB4sKBI2QBXwFlwqJvgIsQCS7edVRwWjbpoiGIM+lZpcvjD0uXmuhurNGyumXQ TJVBkyb0zfG5YX/XHB40RNMJzjFuiMPDLVQmmDE//FOuWqBG88MgJP9Ghk3J7wA2 JfDPavb49EzOCSh74zJWP7/QyybzF3ABCMu4OFkaOdqso8FS659XI55QReBbUppu FRkOgao1BclJhbBdrdtLNjlETM82tfVgW56vaIrrU2z7HskihEyMdB4c+CYbBnPx QqIhkhUCgYEA0SLVExtNy5Gmi6/ZY9tcd3QIuxcN6Xiup+LgIhWK3+GIoVOPsOjN 
def _new_cert(issuer=None, is_issuer=False, serial_number=None, **subject):
    """Create a certificate together with a freshly generated RSA key.

    :param issuer:
        Optional ``(certificate, private_key)`` pair of the signing CA.
        When given, the new certificate copies the issuer's validity
        window and is signed with the issuer's key.  When omitted, the
        certificate is self-signed and valid from one day before now
        until 1000 weeks from now.
    :param is_issuer:
        If true, add the CA extensions (``BasicConstraints(ca=True)`` and
        a ``KeyUsage`` allowing certificate and CRL signing).  Otherwise
        add the leaf extensions: ``KeyUsage`` for digital signature and
        key encipherment, ``BasicConstraints(ca=False)``, the
        ``SERVER_AUTH`` extended key usage and a ``localhost`` DNS
        subject alternative name.
    :param serial_number:
        Serial number to use.  A falsy value selects a random number
        built from eight bytes of ``os.urandom``.
    :param subject:
        Subject name attributes, keyed by the lower-case name of an
        ``oid.NameOID`` member (e.g. ``common_name="localhost"``).

    :return: A ``(certificate, private_key)`` tuple.
    """
    backend = backends.default_backend()
    private_key = rsa.generate_private_key(
        public_exponent=65537, key_size=4096, backend=backend
    )
    public_key = private_key.public_key()
    subject = x509.Name(
        [
            x509.NameAttribute(getattr(oid.NameOID, key.upper()), value)
            for key, value in subject.items()
        ]
    )
    builder = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .public_key(public_key)
        .serial_number(serial_number or int.from_bytes(os.urandom(8), "big"))
    )
    if issuer:
        issuer_cert, signing_key = issuer
        builder = (
            builder.issuer_name(issuer_cert.subject)
            .not_valid_before(issuer_cert.not_valid_before)
            .not_valid_after(issuer_cert.not_valid_after)
        )
        aki_ext = x509.AuthorityKeyIdentifier(
            key_identifier=issuer_cert.extensions.get_extension_for_class(
                x509.SubjectKeyIdentifier
            ).value.digest,
            authority_cert_issuer=[x509.DirectoryName(issuer_cert.subject)],
            authority_cert_serial_number=issuer_cert.serial_number,
        )
    else:
        signing_key = private_key
        # Use one timezone-aware UTC timestamp.  A naive local-time value
        # would be interpreted as UTC by the certificate builder, shifting
        # the validity window by the local UTC offset.
        now = datetime.datetime.now(datetime.timezone.utc)
        builder = (
            builder.issuer_name(subject)
            .not_valid_before(now - datetime.timedelta(days=1))
            .not_valid_after(now + datetime.timedelta(weeks=1000))
        )
        aki_ext = x509.AuthorityKeyIdentifier.from_issuer_public_key(
            public_key
        )
    if is_issuer:
        builder = builder.add_extension(
            x509.BasicConstraints(ca=True, path_length=None),
            critical=True,
        ).add_extension(
            x509.KeyUsage(
                digital_signature=False,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=False,
        )
    else:
        builder = (
            builder.add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    content_commitment=False,
                    key_encipherment=True,
                    data_encipherment=False,
                    key_agreement=False,
                    key_cert_sign=False,
                    crl_sign=False,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=False,
            )
            .add_extension(
                x509.BasicConstraints(ca=False, path_length=None),
                critical=True,
            )
            .add_extension(
                x509.ExtendedKeyUsage([oid.ExtendedKeyUsageOID.SERVER_AUTH]),
                critical=False,
            )
            .add_extension(
                x509.SubjectAlternativeName([x509.DNSName("localhost")]),
                critical=False,
            )
        )
    # Both certificate flavours end with the same two key-identifier
    # extensions, in this order.
    builder = builder.add_extension(
        x509.SubjectKeyIdentifier.from_public_key(public_key),
        critical=False,
    ).add_extension(
        aki_ext,
        critical=False,
    )
    certificate = builder.sign(
        private_key=signing_key,
        algorithm=hashes.SHA256(),
        backend=backend,
    )
    return certificate, private_key
def new_crl(path, issuer, cert):
    """Write ``<path>.crl.pem``: a CRL from *issuer* that revokes *cert*.

    :param path:
        Output path prefix; ``".crl.pem"`` is appended to it.
    :param issuer:
        ``(certificate, private_key)`` pair of the CA signing the CRL.
    :param cert:
        ``(certificate, private_key)`` pair whose certificate is revoked.
    """
    issuer_cert, signing_key = issuer
    # Take a single timezone-aware UTC timestamp.  Naive local time would
    # be read as UTC by the builders, so east of UTC the revocation date
    # and ``last_update`` would land in the future and the CRL would be
    # considered not yet valid.
    now = datetime.datetime.now(datetime.timezone.utc)
    revoked_cert = (
        x509.RevokedCertificateBuilder()
        .serial_number(cert[0].serial_number)
        .revocation_date(now)
        .build()
    )
    builder = (
        x509.CertificateRevocationListBuilder()
        .issuer_name(issuer_cert.subject)
        .last_update(now)
        .next_update(now + datetime.timedelta(days=1))
        .add_revoked_certificate(revoked_cert)
    )
    crl = builder.sign(private_key=signing_key, algorithm=hashes.SHA256())
    with open(path + ".crl.pem", "wb") as f:
        f.write(crl.public_bytes(encoding=serialization.Encoding.PEM))
__name__ == "__main__": main() ================================================ FILE: tests/certs/server.cert.pem ================================================ -----BEGIN CERTIFICATE----- MIIG5jCCBM6gAwIBAgICEAAwDQYJKoZIhvcNAQELBQAwgaExCzAJBgNVBAYTAkNB MRAwDgYDVQQIDAdPbnRhcmlvMRAwDgYDVQQHDAdUb3JvbnRvMRgwFgYDVQQKDA9N YWdpY1N0YWNrIEluYy4xFjAUBgNVBAsMDWFzeW5jcGcgdGVzdHMxHTAbBgNVBAMM FGFzeW5jcGcgdGVzdCByb290IGNhMR0wGwYJKoZIhvcNAQkBFg5oZWxsb0BtYWdp Yy5pbzAeFw0yNDEwMTYxNzIzNTZaFw00MzEyMTcxNzIzNTZaMIGEMQswCQYDVQQG EwJDQTEQMA4GA1UECAwHT250YXJpbzEYMBYGA1UECgwPTWFnaWNTdGFjayBJbmMu MRYwFAYDVQQLDA1hc3luY3BnIHRlc3RzMRIwEAYDVQQDDAlsb2NhbGhvc3QxHTAb BgkqhkiG9w0BCQEWDmhlbGxvQG1hZ2ljLmlvMIICIjANBgkqhkiG9w0BAQEFAAOC Ag8AMIICCgKCAgEA3F017q/obCM1SsHY5dFz72pFgVMhBIZ6kdIInbFv7RmEykZz ubbJnrgwgYDO5FKGUNO+a80AbjIvBrtPtXs9Ip/QDg0jqgw/MOADCxCzYnAQ2Ew2 y1PfspGtdPhLNTmrO8+AxU2XmjsYY0+ysgUQQttOs9hJ79pIsKGBEES8g9oJTiIf tKgCxCIuhiZC+AgjeIQZUB9ccifmOGrCJYrD6LBuNGoQNW2/ykqjuHE8219dv1hV do8azcp/WmejjQguZyU3S/AofnyyNE24rWpXbbFs+9FFaUXd8g/fWCwrRmcXpOaE lvkmMZyuT9kuglHsvpzzGGNSUpvVoPfldk/4JY/kJrA2G5pgTX6mGRYGEN0jmlCa yg/ZFn36G0mA5ZBH4Qln+lKUSjJH8bhlFXvXlE3Mc34OCdOAp1TRfOT/qCRKo9A5 KCjVOvG5MAKE8TZnTFLCSx5gK/EdQ2iV7Sm3aVc2P4eEJh+nvv1LDVLQEAak6U+u sZN5+Wnu7wDKSlh80vTTtoqls5Uo3gIxHYnqX5Fj6nwCzGjjXISNE4OKZLuk3can mciEES3plUrut+O6a2JWiDoCrwX4blYXhtL92Xaer/Mk1TSf2JsmL6pODoapsA0S CHtpcgoodxdKriy1qUGsiNlPNVWjASGyKXoEZdv49wyoZuysudl1aS1w42UCAwEA AaOCAUEwggE9MAsGA1UdDwQEAwIFoDAMBgNVHRMBAf8EAjAAMBMGA1UdJQQMMAoG CCsGAQUFBwMBMBQGA1UdEQQNMAuCCWxvY2FsaG9zdDAdBgNVHQ4EFgQUO/cXg1uX 2oHZodbw6F3/HakLdaQwgdUGA1UdIwSBzTCByoAUhGQbAW97KXQs68Z3efEj55zs c4WhgaekgaQwgaExCzAJBgNVBAYTAkNBMRAwDgYDVQQIDAdPbnRhcmlvMRAwDgYD VQQHDAdUb3JvbnRvMRgwFgYDVQQKDA9NYWdpY1N0YWNrIEluYy4xFjAUBgNVBAsM DWFzeW5jcGcgdGVzdHMxHTAbBgNVBAMMFGFzeW5jcGcgdGVzdCByb290IGNhMR0w GwYJKoZIhvcNAQkBFg5oZWxsb0BtYWdpYy5pb4IICJCUmtkcj2MwDQYJKoZIhvcN AQELBQADggIBAD4Ti52nEttUNay+sqqbDLtnSyMRsJI8agPqiHz6bYifSf530rlh 
qlHYUY5tgfrd8yDZNIe9Ib7Q1WQjgR8c/T9SoFnLl/tff1CVOAYQ/ffCZGTdBOSc KfdKEEvObWxWsqv31ZAMWVzfPsF7rwbTbZ8YdH2CNjxbZxrSEn2IrjplsoP5WMsE 6t7Q+J5wpi2yiEI9PoY2wH5WBB8ONWvZfj9r6OrczlTEZ+L6eiip5kMiw5R9EVt6 ju2aMWqbZTI49Mu/qvXRAkwYvX7mrhuW/4mPHOW/zSnN7hOyjntx1fdnpPD5BTT6 CoJ7nhWgnntw2kk2V9UBCYpVeqidDRrs+nr1xSpduuM1ve3SDkIpd6EGEUqZJ12s 5xpCUFK67atCZOXbJXqanm+3N9kbqYuwkWoqnPjOfMYW7oABmUy8elVGGwTuiTI0 sXS3aQJ+Bm7oqSXrIxUTjOUUaYNhhaqZdXaO/29vI2+i975Pt1ZLLPUkp0hsUgTT kryN02TlNTxxQafTWad6YdzyrwvMpV7vxf7JQkOKRwLinqLCDVxjBt66O9mLIpQF WIfWQG+X4sgobB0NTtBWeGkrIgnhUtsT0ibVm4JAC1cbxdLOq2dfcURC8UFWJXok yFr/uaDZiKKbUFXbalZwnx6H6ucfl5No3hheexadyIbPNcHhFJ9zGXot -----END CERTIFICATE----- ================================================ FILE: tests/certs/server.crl.pem ================================================ -----BEGIN X509 CRL----- MIIDAjCB6wIBATANBgkqhkiG9w0BAQsFADCBoTELMAkGA1UEBhMCQ0ExEDAOBgNV BAgMB09udGFyaW8xEDAOBgNVBAcMB1Rvcm9udG8xGDAWBgNVBAoMD01hZ2ljU3Rh Y2sgSW5jLjEWMBQGA1UECwwNYXN5bmNwZyB0ZXN0czEdMBsGA1UEAwwUYXN5bmNw ZyB0ZXN0IHJvb3QgY2ExHTAbBgkqhkiG9w0BCQEWDmhlbGxvQG1hZ2ljLmlvFw0y NDEwMTcxNzIzNTZaFw0yNDEwMTgxNzIzNTZaMBUwEwICEAAXDTI0MTAxNzE3MjM1 NlowDQYJKoZIhvcNAQELBQADggIBAEVNX72KK6etoZQOXzPgd8ZJNrYcsOwjNZFL ZxC47uX+yrxjv7Wrrk4feyakFi5bL9n8/JMggcpxC6yxMQH/sdOZJ0BzKw3GUAxj m53i1GGO1lGdKH5a7uDPZVW362JwCVE81ROCdb1SL/yYmIwhD4w2bqjOQuI63Xe1 MDfVZBqcIwzzkA5PEjTSFQIsBcHU+rDrWggkz/XJh5alRav8Gnj7KTE8U1z5UeKV LUk8L8+ZLW6XlrTnyjOn3qT7sZw2C/R46GCyHWwT5tbLhJhm2u1EuX3Iids02vIP w9bYf7+Uu2lsse9TuFNXtW0UFLdvVezomHjNBCaMI/MIvG4wSWnAo5bTtlowzxSy 7rpQQYBebcl5somUAhHqs4dsxbEwCXMPDdapiXkhxR9R4nDvkfsgwyqIRsWsIEq6 PFjjRySNFUg5/vqhVQrg0hV7ygzXfd/kIlud3ZkKnli51TuFMWKD5sMN0r8ITLdG usoJQiF6G3ByLQBnsiQoHbipWkWTOKmfB/cfaPXdagPZH6rQmJeeNq0vBy6VqbFi 7D+BqABs+yIT6uJEEqyPGJttkUZP+0ziaK+DZF4MgJtiERtz2GjKMeh3h/YSqA27 8El6na7hPA3k1pANkaOaKuxZYzrPsl3P91ISGL6E0dgd6f9NZMOxbhfNKoDsBJnd Hjb3RTY4 -----END X509 CRL----- ================================================ FILE: tests/certs/server.key.pem 
================================================ -----BEGIN RSA PRIVATE KEY----- MIIJKQIBAAKCAgEA3F017q/obCM1SsHY5dFz72pFgVMhBIZ6kdIInbFv7RmEykZz ubbJnrgwgYDO5FKGUNO+a80AbjIvBrtPtXs9Ip/QDg0jqgw/MOADCxCzYnAQ2Ew2 y1PfspGtdPhLNTmrO8+AxU2XmjsYY0+ysgUQQttOs9hJ79pIsKGBEES8g9oJTiIf tKgCxCIuhiZC+AgjeIQZUB9ccifmOGrCJYrD6LBuNGoQNW2/ykqjuHE8219dv1hV do8azcp/WmejjQguZyU3S/AofnyyNE24rWpXbbFs+9FFaUXd8g/fWCwrRmcXpOaE lvkmMZyuT9kuglHsvpzzGGNSUpvVoPfldk/4JY/kJrA2G5pgTX6mGRYGEN0jmlCa yg/ZFn36G0mA5ZBH4Qln+lKUSjJH8bhlFXvXlE3Mc34OCdOAp1TRfOT/qCRKo9A5 KCjVOvG5MAKE8TZnTFLCSx5gK/EdQ2iV7Sm3aVc2P4eEJh+nvv1LDVLQEAak6U+u sZN5+Wnu7wDKSlh80vTTtoqls5Uo3gIxHYnqX5Fj6nwCzGjjXISNE4OKZLuk3can mciEES3plUrut+O6a2JWiDoCrwX4blYXhtL92Xaer/Mk1TSf2JsmL6pODoapsA0S CHtpcgoodxdKriy1qUGsiNlPNVWjASGyKXoEZdv49wyoZuysudl1aS1w42UCAwEA AQKCAgAXD9TfxfPCXWzrsJ3NGhPSr9crpvzYRw/3cs5esn3O3Sd92SGuAz3WfoWV CAX0SdlaBs7xjo1yUDjbsNQGtNRmaz3lj+Ug8WcrlkYQl7mDnnbPgX+6h8HsI5LO SwM+mWpyN/p3Vkd8vJ0wx4Z2sFD4rjruV2m60FK11DEi+A6X6JmmCQGIcTeDjzrk jzHdrfxdqyAlt80qT+1Sui7XVE5sa7Uc3HzAcAaXr81dNXyeThIMPxJdS1y4F258 kkbA27pU0Rrtt5SFUvIoxyQsrJRkcSJsDYVWHxm7MNi5luXF2G7WXcmX2JCcCz8I MZJ3JlvAbGyEgOB8r2e2u5AoHEu7xjpjJ0/6smmig7LDe96uNpg6zDwS3xl6rAup qgwJ5TTwY8BydVOtDqe5Na8yqLtwMr0yA+k2Hz856mzCTJEOI9TaOq/jtq+n4AXW lkBai762oVKSKYCVJSK6eslTf2bAqjT3jakbgqJLKmMo5XvCnYUWWIve0RhQMNT4 0tiLCxKurYa7xPqgW26c/fEHvdBDrU1JAablcAjsW9sJ+KIlilK02M9DqF0RnBBI wK7Ql76ugsYbp8WBXkpFjMMyciMhqH8xJiyi7MuiCwpBGQwxBHHaX7f9OqDWOClR mVGjrZuk9oiI3waUjGG50SzLBlMbeIzMdXgRuM7fByq6DG0VgQKCAQEA8d2YCODh ApCM7GB/tmANfVQ0tnfxUT3ceEAOH7XkI+nz87Zv/1k6NOklCMi+nUwoGQfM5CxU NdWC0I7wI1ATdllPStUAJ4c8xtdEdlrLHBcGNvhYbbqMWRsNGITstnAx3tZ4X32H duhS5wfPE/X25YMN+8Dtm7jifEMqoCUV55iZxfYs+LXxQF03KVAJ5Ie5a1ac5UCz zzu9fbYSs70ByJsHWt4ZOsPkJVmkmuXzUPvr72otUYYSdju0PgbJqRoEyTbCh3HT zo0emKl8jj7oTSzVNjb6AaB6nsKco6wQLQSlaxBzo0j7TBRylVtG81CYjr5LFpp0 UQrHjLZnSTvC5wKCAQEA6T3yH6bFc9FcJGOW1jYozQ5y+NWkXv3MVFIf3IqPT76p rMEI6krmGUKi+otOaV2Axy36kOcbntzENMg++LPCe0SczK14+pwUrI91cp/Ega6K 
+/4sKvh8WDZhzVYkWs76UiRj7Ef4MvtsaPAcFN/Ek+fItDHFRoSGdm+vx+j3ZDxx tdRudTs0kYyhmdlM0kZTbXsmz37x6+45uO16s+D2lvX2PXM9Lve9z/Ti6nn9QvIF kM9ZmAU6epmMPsGKM9WOK/sTcPUnd3Ife9tmi3BRAAygDk6hFx67kAsc124oLeZ3 0CJGshA+50hBAL7wiybLrBMRzHrElzsicppVbn3p0wKCAQAldmRBI8vWYNtjFYNS lUghnHRZuvRG2CUY/xrw8HR415jwq9ZnH8PzRBV3adiUdqJTVjD3OqKEgCC1+x3Y 6mNJVoYAmkNe3ASe6+LvzhpdrHdK9maEAHwSpSz/Gj+r9m7TDDcy2zerRErq+/uo JNXsMMNutjBXiWiTRLgKfBQLfkh7MClBELVgec+8d2hA3IDszkqY+8+eDqvIF/aH noPzNYgLHBGeV48z9dGYKHvqlEq0F6cTVIfxhkfhv51msuAA5pl07z2WZadSkBX5 1maW5ZXUwukwbVHw20X12AXdYzXYAoFWzkwWOaiR18SClX47xd/NjXjswJWuBuay oi4LAoIBAQDirP0+nYmQAYwXIWJaVNBaWQyLoLXaS7XkzNuCLncQ/S9RYVkUui3d ptFVxUUzSVf6O0kkwjYpskxNL79jXPBJdGke0gidJktBWTq/Z15G2ibguCicqlnO MSvjrzAtwLGuWwdxfpBMm+TEJ3ZjIwWc6Mo5tZUP74PuXqTrGBI2LDgmiom/DQcN 3SrAplrukMJLyD/zsF/U9vTKMKHrZ1q/Y9Mn7XMszkB+dnSBhIUKJsQZ9CoSgCJR PCD8bIOv1IATZjOCt/7fKt5GNPf30/QkpCB5RxlvqsKGPwaMp9YMpcsTT/x82SUJ CUODQg3sbovKc838d+PPRf04e51DgMNZAoIBAQC2uiJjluIKRabFSeSfu4+I6cEY kXI0F65UAudFmyXVfaQbO9DR0Y4bWPDfXAUimRvxixEhSrSIBZ/itVxzhOvqZrl1 XRCZsTOVoz7Z8lcd8opxPBnWDk1m2nyajwPXp8ZLo67FG0bWbayVBBRxyvirrZjG PatRKMTyVLTCD+WlQiP4b4kShKdWA4ZH6pHUIviAotWqXMTsEKfupg9avxEk8GtH GZnXAmpnBqmbU4+3rNOaCZLdekVCoEtW0NGZEYEV5UQnZoWY6AiUUxGGE/qionKH sdKN+8CowudMH02bo1a0akS+eh+D/SGc/MLofH7uPWtX7l8sTvQivzDIkZeu -----END RSA PRIVATE KEY----- ================================================ FILE: tests/test__environment.py ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 import os import unittest import asyncpg import asyncpg.serverversion from asyncpg import _testbase as tb class TestEnvironment(tb.ConnectedTestCase): @unittest.skipIf(not os.environ.get('PGVERSION'), "environ[PGVERSION] is not set") async def test_environment_server_version(self): pgver = os.environ.get('PGVERSION') env_ver = 
class TestCodeQuality(unittest.TestCase):
    """Run the project's static checkers as part of the test suite."""

    def test_flake8(self):
        """Run flake8 from the repository root using the ``.flake8`` config.

        Skipped when flake8 is not importable; fails with the checker's
        combined stdout/stderr when it exits with a non-zero status.
        """
        try:
            import flake8  # NoQA
        except ImportError:
            raise unittest.SkipTest('flake8 module is missing')

        root_path = find_root()
        config_path = os.path.join(root_path, '.flake8')
        if not os.path.exists(config_path):
            raise RuntimeError('could not locate .flake8 file')

        try:
            subprocess.run(
                [sys.executable, '-m', 'flake8', '--config', config_path],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                cwd=root_path)
        except subprocess.CalledProcessError as ex:
            output = ex.output.decode()
            raise AssertionError(
                'flake8 validation failed:\n{}'.format(output)) from None

    def test_mypy(self):
        """Run mypy on the ``asyncpg`` package using ``pyproject.toml``.

        Skipped when mypy is not importable; fails with the checker's
        combined stdout/stderr when it exits with a non-zero status.
        """
        try:
            import mypy  # NoQA
        except ImportError:
            raise unittest.SkipTest('mypy module is missing')

        root_path = find_root()
        config_path = os.path.join(root_path, 'pyproject.toml')
        if not os.path.exists(config_path):
            # Name the file that is actually being looked up.
            raise RuntimeError('could not locate pyproject.toml file')

        try:
            subprocess.run(
                [
                    sys.executable,
                    '-m',
                    'mypy',
                    '--config-file',
                    config_path,
                    'asyncpg'
                ],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                cwd=root_path
            )
        except subprocess.CalledProcessError as ex:
            output = ex.output.decode()
            raise AssertionError(
                'mypy validation failed:\n{}'.format(output)) from None
@tb.with_timeout(30.0)
async def test_pool_handles_abrupt_connection_loss(self):
    """Cut connectivity while pool workers are running queries.

    Nine workers share a three-connection pool, each running a
    half-second ``pg_sleep``.  Connectivity loss is triggered one second
    after the workers are scheduled.  Worker errors are collected rather
    than raised; the run is wrapped in ``assertRunUnder`` with the
    computed worst-case budget.
    """
    size = 3
    sleep_for = 0.5
    acquire_timeout = 1.0
    command_timeout = 1.0
    n_workers = 9
    # Number of worker "rounds" each pool slot has to serve (ceiling).
    rounds = (n_workers - 1) // size + 1

    # Worst expected runtime + 20% to account for other latencies.
    budget = (acquire_timeout + command_timeout) * rounds * 1.2

    async def run_query(pool):
        async with pool.acquire(timeout=acquire_timeout) as con:
            await con.fetch('SELECT pg_sleep($1)', sleep_for)

    # Only build the awaitable here; it is awaited inside the timed
    # section below so pool start-up counts against the budget.
    pending_pool = self.create_pool(
        database='postgres',
        min_size=size,
        max_size=size,
        timeout=command_timeout,
        command_timeout=command_timeout,
    )

    with self.assertRunUnder(budget):
        pool = await pending_pool
        try:
            jobs = []
            for _ in range(n_workers):
                jobs.append(run_query(pool))
            self.loop.call_later(1, self.proxy.trigger_connectivity_loss)
            await asyncio.gather(*jobs, return_exceptions=True)
        finally:
            pool.terminate()
async def test_prepare_cache_invalidation_in_pool(self):
    """Check that a stale statement on one pooled connection invalidates
    the cached statements of the other connection in the same pool.

    Both pooled connections cache ``SELECT * FROM tab1``; the column type
    is then altered through ``self.con``.  Re-running the query on the
    first connection must succeed and leave the previously cached
    statements of both connections closed.
    """
    pool = await self.create_pool(database='postgres',
                                  min_size=2, max_size=2)
    # Pre-bind so cleanup never trips over an unbound name when a step
    # before (or during) acquisition fails.
    con1 = con2 = None

    try:
        await self.con.execute('CREATE TABLE tab1(a int, b int)')

        try:
            await self.con.execute('INSERT INTO tab1 VALUES (1, 2)')

            con1 = await pool.acquire()
            con2 = await pool.acquire()

            result = await con1.fetchrow('SELECT * FROM tab1')
            self.assertEqual(result, (1, 2))

            result = await con2.fetchrow('SELECT * FROM tab1')
            self.assertEqual(result, (1, 2))

            statements1 = self._get_cached_statements(con1)
            self._check_statements_are_not_closed(statements1)

            statements2 = self._get_cached_statements(con2)
            self._check_statements_are_not_closed(statements2)

            await self.con.execute(
                'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text')

            # con1 tries the same plan, will invalidate the cache
            # for the entire pool.
            result = await con1.fetchrow('SELECT * FROM tab1')
            self.assertEqual(result, (1, '2'))

            self._check_statements_are_closed(statements1)
            self._check_statements_are_closed(statements2)

            async with con2.transaction():
                # This should work, as con1 should have invalidated
                # the plan cache.
                result = await con2.fetchrow('SELECT * FROM tab1')
                self.assertEqual(result, (1, '2'))

        finally:
            await self.con.execute('DROP TABLE tab1')
    finally:
        # Release in the original order (con2, then con1) and always
        # close the pool, even if table setup or teardown failed.
        for con in (con2, con1):
            if con is not None:
                await pool.release(con)
        await pool.close()
async def test_type_cache_invalidation_in_cancelled_transaction(self):
    """Type-cache invalidation when the schema change is rolled back.

    A composite type gains an attribute inside a transaction that is then
    aborted.  The first fetch after each schema flip (the ALTER and the
    rollback) must raise ``OutdatedSchemaCacheError``; the fetch after
    that must return data matching the schema in effect.
    """
    con = self.con
    select_all = 'SELECT * FROM tab1'

    await con.execute('CREATE TYPE typ1 AS (x int, y int)')
    await con.execute('CREATE TABLE tab1(a int, b typ1)')

    try:
        await con.execute('INSERT INTO tab1 VALUES (1, (2, 3))')
        row = await con.fetchrow(select_all)
        self.assertEqual(row, (1, (2, 3)))

        cached = self._get_cached_statements()
        self._check_statements_are_not_closed(cached)

        try:
            async with con.transaction():
                await con.execute('ALTER TYPE typ1 ADD ATTRIBUTE c text')

                with self.assertRaisesRegex(
                        asyncpg.OutdatedSchemaCacheError, ERRNUM):
                    await con.fetchrow(select_all)

                self._check_statements_are_closed(cached)

                # The second request must be correct (cache was dropped):
                row = await con.fetchrow(select_all)
                self.assertEqual(row, (1, (2, 3, None)))

                # Abort the transaction so the ALTER is rolled back.
                raise UserWarning
        except UserWarning:
            pass

        # The codecs cached inside the transaction describe the
        # rolled-back schema, so the first fetch fails once more.
        with self.assertRaisesRegex(
                asyncpg.OutdatedSchemaCacheError, ERRNUM):
            await con.fetchrow(select_all)

        # This is now OK, the cache is filled after being dropped.
        row = await con.fetchrow(select_all)
        self.assertEqual(row, (1, (2, 3)))

    finally:
        for cleanup in ('DROP TABLE tab1', 'DROP TYPE typ1'):
            await con.execute(cleanup)
async def test_type_cache_invalidation_on_drop_type_attr(self):
    """Dropping a composite-type attribute must invalidate cached codecs.

    After ``ALTER TYPE ... DROP ATTRIBUTE`` the first fetch raises
    ``OutdatedSchemaCacheError`` and closes the cached statements; the
    next fetch returns rows with the reduced composite value.
    """
    con = self.con
    select_all = 'SELECT * FROM tab1'

    await con.execute('CREATE TYPE typ1 AS (x int, y int, c text)')
    await con.execute('CREATE TABLE tab1(a int, b typ1)')

    try:
        await con.execute('INSERT INTO tab1 VALUES (1, (2, 3, $1))', 'x')
        row = await con.fetchrow(select_all)
        self.assertEqual(row, (1, (2, 3, 'x')))

        cached = self._get_cached_statements()
        self._check_statements_are_not_closed(cached)

        await con.execute('ALTER TYPE typ1 DROP ATTRIBUTE x')

        with self.assertRaisesRegex(
                asyncpg.OutdatedSchemaCacheError, ERRNUM):
            await con.fetchrow(select_all)

        self._check_statements_are_closed(cached)

        # This is now OK, the cache is filled after being dropped.
        row = await con.fetchrow(select_all)
        self.assertEqual(row, (1, (3, 'x')))

    finally:
        for cleanup in ('DROP TABLE tab1', 'DROP TYPE typ1'):
            await con.execute(cleanup)
async def test_type_cache_invalidation_in_pool(self):
    """Check that type-cache invalidation spans a pool but not databases.

    Two pools are used: one on ``postgres`` and one on a scratch
    ``testdb`` database holding the same schema.  After the composite
    type is altered in both databases, a failing fetch on one ``postgres``
    connection must close the cached statements of both ``postgres``
    connections, while the ``testdb`` connection keeps its cached
    statements until it hits the error itself.
    """
    await self.con.execute('CREATE DATABASE testdb')
    pool = await self.create_pool(database='postgres',
                                  min_size=2, max_size=2)

    pool_chk = await self.create_pool(database='testdb',
                                      min_size=2, max_size=2)

    await self.con.execute('CREATE TYPE typ1 AS (x int, y int)')
    await self.con.execute('CREATE TABLE tab1(a int, b typ1)')

    # Pre-bind so the cleanup below never references an unbound name
    # when a step before (or during) acquisition fails.
    con1 = con2 = con_chk = None

    try:
        await self.con.execute('INSERT INTO tab1 VALUES (1, (2, 3))')

        con1 = await pool.acquire()
        con2 = await pool.acquire()

        result = await con1.fetchrow('SELECT * FROM tab1')
        self.assertEqual(result, (1, (2, 3)))

        statements1 = self._get_cached_statements(con1)
        self._check_statements_are_not_closed(statements1)

        result = await con2.fetchrow('SELECT * FROM tab1')
        self.assertEqual(result, (1, (2, 3)))

        statements2 = self._get_cached_statements(con2)
        self._check_statements_are_not_closed(statements2)

        # Create the same schema in the "testdb", fetch data which caches
        # type info.
        con_chk = await pool_chk.acquire()
        await con_chk.execute('CREATE TYPE typ1 AS (x int, y int)')
        await con_chk.execute('CREATE TABLE tab1(a int, b typ1)')
        await con_chk.execute('INSERT INTO tab1 VALUES (1, (2, 3))')
        result = await con_chk.fetchrow('SELECT * FROM tab1')
        self.assertEqual(result, (1, (2, 3)))

        statements_chk = self._get_cached_statements(con_chk)
        self._check_statements_are_not_closed(statements_chk)

        # Change schema in the databases.
        await self.con.execute('ALTER TYPE typ1 ADD ATTRIBUTE c text')
        await con_chk.execute('ALTER TYPE typ1 ADD ATTRIBUTE c text')

        # con1 tries to get cached type info, fails, but invalidates the
        # cache for the entire pool.
        with self.assertRaisesRegex(
                asyncpg.OutdatedSchemaCacheError, ERRNUM):
            await con1.fetchrow('SELECT * FROM tab1')

        self._check_statements_are_closed(statements1)
        self._check_statements_are_closed(statements2)

        async with con2.transaction():
            # This should work, as con1 should have invalidated all caches.
            result = await con2.fetchrow('SELECT * FROM tab1')
            self.assertEqual(result, (1, (2, 3, None)))

        # After all the con1 uses actual info from renewed cache entry.
        result = await con1.fetchrow('SELECT * FROM tab1')
        self.assertEqual(result, (1, (2, 3, None)))

        # Check the invalidation is database-specific, i.e. cache entries
        # for pool_chk/con_chk were not dropped via pool/con1.

        self._check_statements_are_not_closed(statements_chk)

        with self.assertRaisesRegex(
                asyncpg.OutdatedSchemaCacheError, ERRNUM):
            await con_chk.fetchrow('SELECT * FROM tab1')

        self._check_statements_are_closed(statements_chk)
    finally:
        await self.con.execute('DROP TABLE tab1')
        await self.con.execute('DROP TYPE typ1')

        if con2 is not None:
            await pool.release(con2)
        if con1 is not None:
            await pool.release(con1)
        await pool.close()

        if con_chk is not None:
            await pool_chk.release(con_chk)
        await pool_chk.close()

        await self.con.execute('DROP DATABASE testdb')
async def test1(): val = await self.con.fetchval('SELECT 42') self.assertEqual(val, 42) async def test2(): val = await self.con.fetchrow('SELECT 42') self.assertEqual(val, (42,)) async def test3(): val = await self.con.fetch('SELECT 42') self.assertEqual(val, [(42,)]) async def test4(): val = await self.con.prepare('SELECT 42') self.assertEqual(await val.fetchval(), 42) async def test5(): self.assertEqual(await st1000.fetchval(), 1000) async def test6(): self.assertEqual(await st1000.fetchrow(), (1000,)) async def test7(): self.assertEqual(await st1000.fetch(), [(1000,)]) async def test8(): cur = await st1000.cursor() self.assertEqual(await cur.fetchrow(), (1000,)) for test in {test0, test1, test2, test3, test4, test5, test6, test7, test8}: with self.subTest(testfunc=test), self.assertRunUnder(1): st = await self.con.prepare('SELECT pg_sleep(20)') task = self.loop.create_task(st.fetch()) await asyncio.sleep(0.05) task.cancel() with self.assertRaises(asyncio.CancelledError): await task async with self.con.transaction(): await test() async def test_cancellation_02(self): st = await self.con.prepare('SELECT 1') task = self.loop.create_task(st.fetch()) await asyncio.sleep(0.05) task.cancel() self.assertEqual(await task, [(1,)]) async def test_cancellation_03(self): with self.assertRaises(asyncpg.InFailedSQLTransactionError): async with self.con.transaction(): task = self.loop.create_task( self.con.fetch('SELECT pg_sleep(20)')) await asyncio.sleep(0.05) task.cancel() with self.assertRaises(asyncio.CancelledError): await task await self.con.fetch('SELECT generate_series(0, 100)') self.assertEqual( await self.con.fetchval('SELECT 42'), 42) async def test_cancellation_04(self): await self.con.fetchval('SELECT pg_sleep(0)') waiter = asyncio.Future() self.con._cancel_current_command(waiter) await waiter self.assertEqual(await self.con.fetchval('SELECT 42'), 42) ================================================ FILE: tests/test_codecs.py 
================================================
# Copyright (C) 2016-present the asyncpg authors and contributors
#
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import datetime
import decimal
import ipaddress
import math
import os
import random
import struct
import unittest
import uuid

import asyncpg
from asyncpg import _testbase as tb
from asyncpg import cluster as pg_cluster


def _timezone(offset):
    """Return a fixed-offset ``datetime.timezone`` for *offset*.

    *offset* is floor-divided by 60 and the quotient is used as a number
    of minutes, so any remainder is dropped.
    """
    minutes = offset // 60
    return datetime.timezone(datetime.timedelta(minutes=minutes))


def _system_timezone():
    """Return a fixed-offset timezone matching the current local UTC offset."""
    # astimezone() with no argument converts to the system local zone.
    d = datetime.datetime.now(datetime.timezone.utc).astimezone()
    return datetime.timezone(d.utcoffset())


# Extreme representable values, used below as the 'infinity' /
# '-infinity' inputs and expected outputs of the date/time samples.
infinity_datetime = datetime.datetime(
    datetime.MAXYEAR, 12, 31, 23, 59, 59, 999999)
negative_infinity_datetime = datetime.datetime(
    datetime.MINYEAR, 1, 1, 0, 0, 0, 0)

infinity_date = datetime.date(datetime.MAXYEAR, 12, 31)
negative_infinity_date = datetime.date(datetime.MINYEAR, 1, 1)
# Captured once at import time.
current_timezone = _system_timezone()
current_date = datetime.date.today()
current_datetime = datetime.datetime.now()


type_samples = [ ('bool', 'bool', ( True, False, )), ('smallint', 'int2', ( -2 ** 15, 2 ** 15 - 1, -1, 0, 1, )), ('int', 'int4', ( -2 ** 31, 2 ** 31 - 1, -1, 0, 1, )), ('bigint', 'int8', ( -2 ** 63, 2 ** 63 - 1, -1, 0, 1, )), ('numeric', 'numeric', ( -(2 ** 64), 2 ** 64, -(2 ** 128), 2 ** 128, -1, 0, 1, decimal.Decimal("0.00000000000000"), decimal.Decimal("1.00000000000000"), decimal.Decimal("-1.00000000000000"), decimal.Decimal("-2.00000000000000"), decimal.Decimal("1000000000000000.00000000000000"), decimal.Decimal(1234), decimal.Decimal(-1234), decimal.Decimal("1234000000.00088883231"), decimal.Decimal(str(1234.00088883231)), decimal.Decimal("3123.23111"), decimal.Decimal("-3123000000.23111"), decimal.Decimal("3123.2311100000"), decimal.Decimal("-03123.0023111"), decimal.Decimal("3123.23111"), decimal.Decimal("3123.23111"), decimal.Decimal("10000.23111"),
decimal.Decimal("100000.23111"), decimal.Decimal("1000000.23111"), decimal.Decimal("10000000.23111"), decimal.Decimal("100000000.23111"), decimal.Decimal("1000000000.23111"), decimal.Decimal("1000000000.3111"), decimal.Decimal("1000000000.111"), decimal.Decimal("1000000000.11"), decimal.Decimal("100000000.0"), decimal.Decimal("10000000.0"), decimal.Decimal("1000000.0"), decimal.Decimal("100000.0"), decimal.Decimal("10000.0"), decimal.Decimal("1000.0"), decimal.Decimal("100.0"), decimal.Decimal("100"), decimal.Decimal("100.1"), decimal.Decimal("100.12"), decimal.Decimal("100.123"), decimal.Decimal("100.1234"), decimal.Decimal("100.12345"), decimal.Decimal("100.123456"), decimal.Decimal("100.1234567"), decimal.Decimal("100.12345679"), decimal.Decimal("100.123456790"), decimal.Decimal("100.123456790000000000000000"), decimal.Decimal("1.0"), decimal.Decimal("0.0"), decimal.Decimal("-1.0"), decimal.Decimal("1.0E-1000"), decimal.Decimal("1E1000"), decimal.Decimal("0.000000000000000000000000001"), decimal.Decimal("0.000000000000010000000000001"), decimal.Decimal("0.00000000000000000000000001"), decimal.Decimal("0.00000000100000000000000001"), decimal.Decimal("0.0000000000000000000000001"), decimal.Decimal("0.000000000000000000000001"), decimal.Decimal("0.00000000000000000000001"), decimal.Decimal("0.0000000000000000000001"), decimal.Decimal("0.000000000000000000001"), decimal.Decimal("0.00000000000000000001"), decimal.Decimal("0.0000000000000000001"), decimal.Decimal("0.000000000000000001"), decimal.Decimal("0.00000000000000001"), decimal.Decimal("0.0000000000000001"), decimal.Decimal("0.000000000000001"), decimal.Decimal("0.00000000000001"), decimal.Decimal("0.0000000000001"), decimal.Decimal("0.000000000001"), decimal.Decimal("0.00000000001"), decimal.Decimal("0.0000000001"), decimal.Decimal("0.000000001"), decimal.Decimal("0.00000001"), decimal.Decimal("0.0000001"), decimal.Decimal("0.000001"), decimal.Decimal("0.00001"), decimal.Decimal("0.0001"), 
decimal.Decimal("0.001"), decimal.Decimal("0.01"), decimal.Decimal("0.1"), decimal.Decimal("0.10"), decimal.Decimal("0.100"), decimal.Decimal("0.1000"), decimal.Decimal("0.10000"), decimal.Decimal("0.100000"), decimal.Decimal("0.00001000"), decimal.Decimal("0.000010000"), decimal.Decimal("0.0000100000"), decimal.Decimal("0.00001000000"), decimal.Decimal("1" + "0" * 117 + "." + "0" * 161) )), ('bytea', 'bytea', ( bytes(range(256)), bytes(range(255, -1, -1)), b'\x00\x00', b'foo', b'f' * 1024 * 1024, dict(input=bytearray(b'\x02\x01'), output=b'\x02\x01'), )), ('text', 'text', ( '', 'A' * (1024 * 1024 + 11) )), ('"char"', 'char', ( b'a', b'b', b'\x00' )), ('timestamp', 'timestamp', [ datetime.datetime(3000, 5, 20, 5, 30, 10), datetime.datetime(2000, 1, 1, 5, 25, 10), datetime.datetime(500, 1, 1, 5, 25, 10), datetime.datetime(250, 1, 1, 5, 25, 10), infinity_datetime, negative_infinity_datetime, {'textinput': 'infinity', 'output': infinity_datetime}, {'textinput': '-infinity', 'output': negative_infinity_datetime}, {'input': datetime.date(2000, 1, 1), 'output': datetime.datetime(2000, 1, 1)}, {'textinput': '1970-01-01 20:31:23.648', 'output': datetime.datetime(1970, 1, 1, 20, 31, 23, 648000)}, {'input': datetime.datetime(1970, 1, 1, 20, 31, 23, 648000), 'textoutput': '1970-01-01 20:31:23.648'}, ]), ('date', 'date', [ datetime.date(3000, 5, 20), datetime.date(2000, 1, 1), datetime.date(500, 1, 1), infinity_date, negative_infinity_date, {'textinput': 'infinity', 'output': infinity_date}, {'textinput': '-infinity', 'output': negative_infinity_date}, ]), ('time', 'time', [ datetime.time(12, 15, 20), datetime.time(0, 1, 1), datetime.time(23, 59, 59), ]), ('timestamptz', 'timestamptz', [ # It's converted to UTC. When it comes back out, it will be in UTC # again. The datetime comparison will take the tzinfo into account. 
datetime.datetime(1990, 5, 12, 10, 10, 0, tzinfo=_timezone(4000)), datetime.datetime(1982, 5, 18, 10, 10, 0, tzinfo=_timezone(6000)), datetime.datetime(1950, 1, 1, 10, 10, 0, tzinfo=_timezone(7000)), datetime.datetime(1800, 1, 1, 10, 10, 0, tzinfo=_timezone(2000)), datetime.datetime(2400, 1, 1, 10, 10, 0, tzinfo=_timezone(2000)), infinity_datetime, negative_infinity_datetime, { 'input': current_date, 'output': datetime.datetime( year=current_date.year, month=current_date.month, day=current_date.day, tzinfo=current_timezone), }, { 'input': current_datetime, 'output': current_datetime.replace(tzinfo=current_timezone), } ]), ('timetz', 'timetz', [ # timetz retains the offset datetime.time(10, 10, 0, tzinfo=_timezone(4000)), datetime.time(10, 10, 0, tzinfo=_timezone(6000)), datetime.time(10, 10, 0, tzinfo=_timezone(7000)), datetime.time(10, 10, 0, tzinfo=_timezone(2000)), datetime.time(22, 30, 0, tzinfo=_timezone(0)), ]), ('interval', 'interval', [ datetime.timedelta(40, 10, 1234), datetime.timedelta(0, 0, 4321), datetime.timedelta(0, 0), datetime.timedelta(-100, 0), datetime.timedelta(-100, -400), { 'textinput': '-2 years -11 months -10 days ' '-2 hours -800 milliseconds', 'output': datetime.timedelta( days=(-2 * 365) + (-11 * 30) - 10, seconds=(-2 * 3600), milliseconds=-800 ), }, { 'query': 'SELECT justify_hours($1::interval)::text', 'input': datetime.timedelta( days=(-2 * 365) + (-11 * 30) - 10, seconds=(-2 * 3600), milliseconds=-800 ), 'textoutput': '-1070 days -02:00:00.8', }, ]), ('uuid', 'uuid', [ uuid.UUID('38a4ff5a-3a56-11e6-a6c2-c8f73323c6d4'), uuid.UUID('00000000-0000-0000-0000-000000000000'), {'input': '00000000-0000-0000-0000-000000000000', 'output': uuid.UUID('00000000-0000-0000-0000-000000000000')} ]), ('uuid[]', 'uuid[]', [ [uuid.UUID('38a4ff5a-3a56-11e6-a6c2-c8f73323c6d4'), uuid.UUID('00000000-0000-0000-0000-000000000000')], [] ]), ('json', 'json', [ '[1, 2, 3, 4]', '{"a": [1, 2], "b": 0}' ]), ('jsonb', 'jsonb', [ '[1, 2, 3, 4]', '{"a": [1, 2], "b": 
0}' ], (9, 4)), ('jsonpath', 'jsonpath', [ '$."track"."segments"[*]."HR"?(@ > 130)', ], (12, 0)), ('oid[]', 'oid[]', [ [1, 2, 3, 4], [] ]), ('smallint[]', 'int2[]', [ [1, 2, 3, 4], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0], [] ]), ('bigint[]', 'int8[]', [ [2 ** 42, -2 ** 54, 0], [] ]), ('int[]', 'int4[]', [ [2 ** 22, -2 ** 24, 0], [] ]), ('time[]', 'time[]', [ [datetime.time(12, 15, 20), datetime.time(0, 1, 1)], [] ]), ('text[]', 'text[]', [ ['ABCDE', 'EDCBA'], [], ['A' * 1024 * 1024] * 10 ]), ('float8', 'float8', [ 1.1, -1.1, 0, 2, 1e-4, -1e-20, 122.2e-100, 2e5, math.pi, math.e, math.inf, -math.inf, math.nan, {'textinput': 'infinity', 'output': math.inf}, {'textinput': '-infinity', 'output': -math.inf}, {'textinput': 'NaN', 'output': math.nan}, ]), ('float4', 'float4', [ 1.1, -1.1, 0, 2, 1e-4, -1e-20, 2e5, math.pi, math.e, math.inf, -math.inf, math.nan, {'textinput': 'infinity', 'output': math.inf}, {'textinput': '-infinity', 'output': -math.inf}, {'textinput': 'NaN', 'output': math.nan}, ]), ('cidr', 'cidr', [ ipaddress.IPv4Network('255.255.255.255/32'), ipaddress.IPv4Network('127.0.0.0/8'), ipaddress.IPv4Network('127.1.0.0/16'), ipaddress.IPv4Network('127.1.0.0/18'), ipaddress.IPv4Network('10.0.0.0/32'), ipaddress.IPv4Network('0.0.0.0/0'), ipaddress.IPv6Network('ffff' + ':ffff' * 7 + '/128'), ipaddress.IPv6Network('::1/128'), ipaddress.IPv6Network('::/0'), ]), ('inet', 'inet', [ ipaddress.IPv4Address('255.255.255.255'), ipaddress.IPv4Address('127.0.0.1'), ipaddress.IPv4Address('0.0.0.0'), ipaddress.IPv6Address('ffff' + ':ffff' * 7), ipaddress.IPv6Address('::1'), ipaddress.IPv6Address('::'), ipaddress.IPv4Interface('10.0.0.1/30'), ipaddress.IPv4Interface('0.0.0.0/0'), ipaddress.IPv4Interface('255.255.255.255/31'), dict( input='127.0.0.0/8', output=ipaddress.IPv4Interface('127.0.0.0/8')), dict( input='127.0.0.1/32', output=ipaddress.IPv4Address('127.0.0.1')), # Postgres appends /32 when casting to text explicitly, but # *not* in inet_out. 
dict( input='10.11.12.13', textoutput='10.11.12.13/32' ), dict( input=ipaddress.IPv4Address('10.11.12.13'), textoutput='10.11.12.13/32' ), dict( input=ipaddress.IPv4Interface('10.11.12.13'), textoutput='10.11.12.13/32' ), dict( textinput='10.11.12.13', output=ipaddress.IPv4Address('10.11.12.13'), ), dict( textinput='10.11.12.13/0', output=ipaddress.IPv4Interface('10.11.12.13/0'), ), ]), ('macaddr', 'macaddr', [ '00:00:00:00:00:00', 'ff:ff:ff:ff:ff:ff' ]), ('txid_snapshot', 'txid_snapshot', [ (100, 1000, (100, 200, 300, 400)) ]), ('pg_snapshot', 'pg_snapshot', [ (100, 1000, (100, 200, 300, 400)) ], (13, 0)), ('xid', 'xid', ( 2 ** 32 - 1, 0, 1, )), ('xid8', 'xid8', ( 2 ** 64 - 1, 0, 1, ), (13, 0)), ('varbit', 'varbit', [ asyncpg.BitString('0000 0001'), asyncpg.BitString('00010001'), asyncpg.BitString(''), asyncpg.BitString(), asyncpg.BitString.frombytes(b'\x00', bitlength=3), asyncpg.BitString('0000 0000 1'), dict(input=b'\x01', output=asyncpg.BitString('0000 0001')), dict(input=bytearray(b'\x02'), output=asyncpg.BitString('0000 0010')), ]), ('path', 'path', [ asyncpg.Path(asyncpg.Point(0.0, 0.0), asyncpg.Point(1.0, 1.0)), asyncpg.Path(asyncpg.Point(0.0, 0.0), asyncpg.Point(1.0, 1.0), is_closed=True), dict(input=((0.0, 0.0), (1.0, 1.0)), output=asyncpg.Path(asyncpg.Point(0.0, 0.0), asyncpg.Point(1.0, 1.0), is_closed=True)), dict(input=[(0.0, 0.0), (1.0, 1.0)], output=asyncpg.Path(asyncpg.Point(0.0, 0.0), asyncpg.Point(1.0, 1.0), is_closed=False)), ]), ('point', 'point', [ asyncpg.Point(0.0, 0.0), asyncpg.Point(1.0, 2.0), ]), ('box', 'box', [ asyncpg.Box((1.0, 2.0), (0.0, 0.0)), ]), ('line', 'line', [ asyncpg.Line(1, 2, 3), ], (9, 4)), ('lseg', 'lseg', [ asyncpg.LineSegment((1, 2), (2, 2)), ]), ('polygon', 'polygon', [ asyncpg.Polygon(asyncpg.Point(0.0, 0.0), asyncpg.Point(1.0, 0.0), asyncpg.Point(1.0, 1.0), asyncpg.Point(0.0, 1.0)), ]), ('circle', 'circle', [ asyncpg.Circle((0.0, 0.0), 100), ]), ('tid', 'tid', [ (100, 200), (0, 0), (2147483647, 0), (4294967295, 0), 
(0, 32767), (0, 65535), (4294967295, 65535), ]), ('oid', 'oid', [ 0, 10, 4294967295 ]) ]


class TestCodecs(tb.ConnectedTestCase):

    async def test_standard_codecs(self):
        """Test encoding/decoding of standard data types and arrays thereof."""
        # Each entry is (SQL type name, expected introspected name, samples)
        # with an optional trailing minimum server version.
        for (typname, intname, sample_data, *metadata) in type_samples:
            if metadata and self.server_version < metadata[0]:
                continue

            # Plain round trip through the type's own codec.
            st = await self.con.prepare(
                "SELECT $1::" + typname
            )

            # Input is sent as text and cast to the type on the server.
            text_in = await self.con.prepare(
                "SELECT $1::text::" + typname
            )

            # The value is cast to text on the server before being returned.
            text_out = await self.con.prepare(
                "SELECT $1::" + typname + "::text"
            )

            for sample in sample_data:
                with self.subTest(sample=sample, typname=typname):
                    stmt = st
                    # A dict sample selects the statement and the expected
                    # output explicitly; anything else must round-trip as-is.
                    if isinstance(sample, dict):
                        if 'textinput' in sample:
                            inputval = sample['textinput']
                            stmt = text_in
                        else:
                            inputval = sample['input']

                        if 'textoutput' in sample:
                            outputval = sample['textoutput']
                            if stmt is text_in:
                                raise ValueError(
                                    'cannot test "textin" and'
                                    ' "textout" simultaneously')
                            stmt = text_out
                        else:
                            outputval = sample['output']

                        # A sample may bring its own query.
                        if sample.get('query'):
                            stmt = await self.con.prepare(sample['query'])
                    else:
                        inputval = outputval = sample

                    result = await stmt.fetchval(inputval)
                    err_msg = (
                        "unexpected result for {} when passing {!r}: "
                        "received {!r}, expected {!r}".format(
                            typname, inputval, result, outputval))

                    if typname.startswith('float'):
                        # NaN never compares equal, so it is checked
                        # separately; other floats use a relative tolerance.
                        if math.isnan(outputval):
                            if not math.isnan(result):
                                self.fail(err_msg)
                        else:
                            self.assertTrue(
                                math.isclose(result, outputval, rel_tol=1e-6),
                                err_msg)
                    else:
                        self.assertEqual(result, outputval, err_msg)

                    # Decimal equality ignores trailing zeros, so also
                    # compare the exact (sign, digits, exponent) tuples.
                    if (typname == 'numeric' and
                            isinstance(inputval, decimal.Decimal)):
                        self.assertEqual(
                            result.as_tuple(),
                            outputval.as_tuple(),
                            err_msg,
                        )

            with self.subTest(sample=None, typname=typname):
                # Test that None is handled for all types.
                rsample = await st.fetchval(None)
                self.assertIsNone(rsample)

            at = st.get_attributes()
            self.assertEqual(at[0].type.name, intname)

    async def test_all_builtin_types_handled(self):
        """Check that every OID in BUILTIN_TYPE_OID_MAP has a data codec."""
        from asyncpg.protocol.protocol import BUILTIN_TYPE_OID_MAP

        for oid, typename in BUILTIN_TYPE_OID_MAP.items():
            codec = self.con.get_settings().get_data_codec(oid)
            self.assertIsNotNone(
                codec,
                'core type {} ({}) is unhandled'.format(typename, oid))

    async def test_void(self):
        """Check that pg_sleep(0) yields None and a void argument is accepted."""
        res = await self.con.fetchval('select pg_sleep(0)')
        self.assertIsNone(res)
        await self.con.fetchval('select now($1::void)', '')

    def test_bitstring(self):
        """Check BitString conversions against naive reference computations."""
        # Random input of '1', '0' and spaces; the spaces are stripped
        # before every comparison below.
        bitlen = random.randint(0, 1000)
        bs = ''.join(random.choice(('1', '0', ' ')) for _ in range(bitlen))
        bits = asyncpg.BitString(bs)
        sanitized_bs = bs.replace(' ', '')
        self.assertEqual(sanitized_bs,
                         bits.as_string().replace(' ', ''))

        # Number of whole bytes needed to hold the bits (rounded up).
        expected_bytelen = \
            len(sanitized_bs) // 8 + (1 if len(sanitized_bs) % 8 else 0)

        self.assertEqual(len(bits.bytes), expected_bytelen)

        little, big = bits.to_int('little'), bits.to_int('big')
        self.assertEqual(bits.from_int(little, len(bits), 'little'), bits)
        self.assertEqual(bits.from_int(big, len(bits), 'big'), bits)

        # Reference value with the first bit as the least significant one.
        naive_little = 0
        for i, c in enumerate(sanitized_bs):
            naive_little |= int(c) << i
        # Reference value with the first bit as the most significant one.
        naive_big = 0
        for c in sanitized_bs:
            naive_big = (naive_big << 1) | int(c)

        self.assertEqual(little, naive_little)
        self.assertEqual(big, naive_big)

    async def test_interval(self):
        """Check year/month intervals decode as 365-day years, 30-day months."""
        res = await self.con.fetchval("SELECT '5 years'::interval")
        self.assertEqual(res, datetime.timedelta(days=1825))

        res = await self.con.fetchval("SELECT '5 years 1 month'::interval")
        self.assertEqual(res, datetime.timedelta(days=1855))

        res = await self.con.fetchval("SELECT '-5 years'::interval")
        self.assertEqual(res, datetime.timedelta(days=-1825))

        res = await self.con.fetchval("SELECT '-5 years -1 month'::interval")
        self.assertEqual(res, datetime.timedelta(days=-1855))

async def test_numeric(self): # Test that we handle dscale correctly.
cases = [ '0.001', '0.001000', '1', '1.00000' ] for case in cases: res = await self.con.fetchval( "SELECT $1::numeric", case) self.assertEqual(str(res), case) try: await self.con.execute( ''' CREATE TABLE tab (v numeric(3, 2)); INSERT INTO tab VALUES (0), (1); ''') res = await self.con.fetchval("SELECT v FROM tab WHERE v = $1", 0) self.assertEqual(str(res), '0.00') res = await self.con.fetchval("SELECT v FROM tab WHERE v = $1", 1) self.assertEqual(str(res), '1.00') finally: await self.con.execute('DROP TABLE tab') res = await self.con.fetchval( "SELECT $1::numeric", decimal.Decimal('NaN')) self.assertTrue(res.is_nan()) res = await self.con.fetchval( "SELECT $1::numeric", decimal.Decimal('sNaN')) self.assertTrue(res.is_nan()) if self.server_version < (14, 0): with self.assertRaisesRegex( asyncpg.DataError, 'invalid sign in external "numeric" value' ): await self.con.fetchval( "SELECT $1::numeric", decimal.Decimal('-Inf')) with self.assertRaisesRegex( asyncpg.DataError, 'invalid sign in external "numeric" value' ): await self.con.fetchval( "SELECT $1::numeric", decimal.Decimal('+Inf')) with self.assertRaisesRegex(asyncpg.DataError, 'invalid'): await self.con.fetchval( "SELECT $1::numeric", 'invalid') else: res = await self.con.fetchval( "SELECT $1::numeric", decimal.Decimal("-Inf")) self.assertTrue(res.is_infinite()) res = await self.con.fetchval( "SELECT $1::numeric", decimal.Decimal("+Inf")) self.assertTrue(res.is_infinite()) with self.assertRaisesRegex(asyncpg.DataError, 'invalid'): await self.con.fetchval( "SELECT $1::numeric", 'invalid') async def test_unhandled_type_fallback(self): await self.con.execute(''' CREATE EXTENSION IF NOT EXISTS isn ''') try: input_val = '1436-4522' res = await self.con.fetchrow(''' SELECT $1::issn AS issn, 42 AS int ''', input_val) self.assertEqual(res['issn'], input_val) self.assertEqual(res['int'], 42) finally: await self.con.execute(''' DROP EXTENSION isn ''') async def test_invalid_input(self): # The latter message appears 
beginning in Python 3.10. integer_required = ( r"(an integer is required|" r"\('str' object cannot be interpreted as an integer\))") cases = [ ('bytea', 'a bytes-like object is required', [ 1, 'aaa' ]), ('bool', 'a boolean is required', [ 1, ]), ('int2', integer_required, [ '2', 'aa', ]), ('smallint', 'value out of int16 range', [ 2**256, # check for the same exception for any big numbers decimal.Decimal("2000000000000000000000000000000"), 0xffff, 0xffffffff, 32768, -32769 ]), ('float4', 'value out of float32 range', [ 4.1 * 10 ** 40, -4.1 * 10 ** 40, ]), ('int4', integer_required, [ '2', 'aa', ]), ('int', 'value out of int32 range', [ 2**256, # check for the same exception for any big numbers decimal.Decimal("2000000000000000000000000000000"), 0xffffffff, 2**31, -2**31 - 1, ]), ('int8', integer_required, [ '2', 'aa', ]), ('bigint', 'value out of int64 range', [ 2**256, # check for the same exception for any big numbers decimal.Decimal("2000000000000000000000000000000"), 0xffffffffffffffff, 2**63, -2**63 - 1, ]), ('text', 'expected str, got bytes', [ b'foo' ]), ('text', 'expected str, got list', [ [1] ]), ('tid', 'list or tuple expected', [ b'foo' ]), ('tid', 'invalid number of elements in tid tuple', [ [], (), [1, 2, 3], (4,), ]), ('tid', 'tuple id block value out of uint32 range', [ (-1, 0), (2**256, 0), (0xffffffff + 1, 0), (2**32, 0), ]), ('tid', 'tuple id offset value out of uint16 range', [ (0, -1), (0, 2**256), (0, 0xffff + 1), (0, 0xffffffff), (0, 65536), ]), ('oid', 'value out of uint32 range', [ 2 ** 32, -1, ]), ('timestamp', r"expected a datetime\.date.*got 'str'", [ 'foo' ]), ('timestamptz', r"expected a datetime\.date.*got 'str'", [ 'foo' ]), ] for typname, errmsg, data in cases: stmt = await self.con.prepare("SELECT $1::" + typname) for sample in data: with self.subTest(sample=sample, typname=typname): full_errmsg = ( r'invalid input for query argument \$1:.*' + errmsg) with self.assertRaisesRegex( asyncpg.DataError, full_errmsg): await 
stmt.fetchval(sample)

    async def test_arrays(self):
        """Test encoding/decoding of arrays (particularly multidimensional)."""
        # (query, expected decoded value) pairs for server-built arrays.
        cases = [
            (
                r"SELECT '[1:3][-1:0]={{1,2},{4,5},{6,7}}'::int[]",
                [[1, 2], [4, 5], [6, 7]]
            ),
            (
                r"SELECT '{{{{{{1}}}}}}'::int[]",
                [[[[[[1]]]]]]
            ),
            (
                r"SELECT '{1, 2, NULL}'::int[]::anyarray",
                [1, 2, None]
            ),
            (
                r"SELECT '{}'::int[]",
                []
            ),
        ]

        for sql, expected in cases:
            with self.subTest(sql=sql):
                res = await self.con.fetchval(sql)
                self.assertEqual(res, expected)

        # Six dimensions decode above; seven are rejected.
        with self.assertRaises(asyncpg.ProgramLimitExceededError):
            await self.con.fetchval("SELECT '{{{{{{{1}}}}}}}'::int[]")

        # Values that must round-trip unchanged as query arguments.
        cases = [
            [None],
            [1, 2, 3, 4, 5, 6],
            [[1, 2], [4, 5], [6, 7]],
            [[[1], [2]], [[4], [5]], [[None], [7]]],
            [[[[[[1]]]]]],
            [[[[[[None]]]]]]
        ]

        st = await self.con.prepare(
            "SELECT $1::int[]"
        )

        for case in cases:
            with self.subTest(case=case):
                result = await st.fetchval(case)
                err_msg = (
                    "failed to return array data as-is; "
                    "gave {!r}, received {!r}".format(
                        case, result))

                self.assertEqual(result, case, err_msg)

        # A sized iterable is fine as array input.
        class Iterable:
            def __iter__(self):
                return iter([1, 2, 3])

            def __len__(self):
                return 3

        result = await self.con.fetchval("SELECT $1::int[]", Iterable())
        self.assertEqual(result, [1, 2, 3])

        # A pure container is _not_ OK for array input.
        class SomeContainer:
            def __contains__(self, item):
                return False

        with self.assertRaisesRegex(asyncpg.DataError,
                                    'sized iterable container expected'):
            result = await self.con.fetchval("SELECT $1::int[]",
                                             SomeContainer())

        with self.assertRaisesRegex(asyncpg.DataError, 'dimensions'):
            await self.con.fetchval(
                "SELECT $1::int[]",
                [[[[[[[1]]]]]]])

        with self.assertRaisesRegex(asyncpg.DataError, 'non-homogeneous'):
            await self.con.fetchval(
                "SELECT $1::int[]",
                [1, [1]])

        with self.assertRaisesRegex(asyncpg.DataError, 'non-homogeneous'):
            await self.con.fetchval(
                "SELECT $1::int[]",
                [[1], 1, [2]])

        with self.assertRaisesRegex(asyncpg.DataError,
                                    'invalid array element'):
            await self.con.fetchval(
                "SELECT $1::int[]",
                [1, 't', 2])

        with self.assertRaisesRegex(asyncpg.DataError,
                                    'invalid array element'):
            await self.con.fetchval(
                "SELECT $1::int[]",
                [[1], ['t'], [2]])

        with self.assertRaisesRegex(asyncpg.DataError,
                                    'sized iterable container expected'):
            await self.con.fetchval(
                "SELECT $1::int[]",
                1)

    async def test_composites(self):
        """Test encoding/decoding of composite types."""
        await self.con.execute(''' CREATE TYPE test_composite AS ( a int, b text, c int[] ) ''')

        # An anonymous ROW(...) result decodes to nested tuples.
        st = await self.con.prepare(''' SELECT ROW(NULL, 1234, '5678', ROW(42, '42')) ''')

        res = await st.fetchval()

        self.assertEqual(res, (None, 1234, '5678', (42, '42')))

        with self.assertRaisesRegex(
            asyncpg.UnsupportedClientFeatureError,
            'query argument \\$1: input of anonymous '
            'composite types is not supported',
        ):
            await self.con.fetchval("SELECT (1, 'foo') = $1", (1, 'foo'))

        try:
            st = await self.con.prepare(''' SELECT ROW( NULL, '5678', ARRAY[9, NULL, 11]::int[] )::test_composite AS test ''')

            res = await st.fetch()
            res = res[0]['test']

            # The decoded composite is addressable by attribute name ...
            self.assertIsNone(res['a'])
            self.assertEqual(res['b'], '5678')
            self.assertEqual(res['c'], [9, None, 11])

            # ... and by position.
            self.assertIsNone(res[0])
            self.assertEqual(res[1], '5678')
            self.assertEqual(res[2], [9, None, 11])

            at = st.get_attributes()
            self.assertEqual(len(at), 1)
            self.assertEqual(at[0].name, 'test')
            self.assertEqual(at[0].type.name, 'test_composite')
            self.assertEqual(at[0].type.kind, 'composite')

            # The decoded value is accepted back as input.
            res = await self.con.fetchval(''' SELECT $1::test_composite ''', res)

            # composite input as a mapping
            res = await self.con.fetchval(''' SELECT $1::test_composite ''', {'b': 'foo', 'a': 1, 'c': [1, 2, 3]})

            self.assertEqual(res, (1, 'foo', [1, 2, 3]))

            # Test None padding
            res = await self.con.fetchval(''' SELECT $1::test_composite ''', {'a': 1})

            self.assertEqual(res, (1, None, None))

            with self.assertRaisesRegex(
                    asyncpg.DataError,
                    "'bad' is not a valid element"):
                await self.con.fetchval(
                    "SELECT $1::test_composite",
                    {'bad': 'foo'})

        finally:
            await self.con.execute('DROP TYPE test_composite')

    async def test_domains(self):
        """Test decoding and introspection of (nested) domain types."""
        await self.con.execute(''' CREATE DOMAIN my_dom AS int ''')

        # A domain defined over another domain.
        await self.con.execute(''' CREATE DOMAIN my_dom2 AS my_dom ''')

        try:
            st = await self.con.prepare(''' SELECT 3::my_dom2 ''')
            res = await st.fetchval()

            self.assertEqual(res, 3)

            st = await self.con.prepare(''' SELECT NULL::my_dom2 ''')
            res = await st.fetchval()

            self.assertIsNone(res)

            # The attribute is named after the domain but reported with the
            # 'int4' scalar type.
            at = st.get_attributes()
            self.assertEqual(len(at), 1)
            self.assertEqual(at[0].name, 'my_dom2')
            self.assertEqual(at[0].type.name, 'int4')
            self.assertEqual(at[0].type.kind, 'scalar')

        finally:
            # Drop the dependent domain first.
            await self.con.execute('DROP DOMAIN my_dom2')
            await self.con.execute('DROP DOMAIN my_dom')

    async def test_range_types(self):
        """Test encoding/decoding of range types."""

        # [input, expected] pairs; tuples and Range objects are both
        # accepted as input.
        cases = [
            ('int4range', [
                [(1, 9), asyncpg.Range(1, 10)],
                [asyncpg.Range(0, 9, lower_inc=False, upper_inc=True),
                 asyncpg.Range(1, 10)],
                [(), asyncpg.Range(empty=True)],
                [asyncpg.Range(empty=True), asyncpg.Range(empty=True)],
                [(None, 2), asyncpg.Range(None, 3)],
                [asyncpg.Range(None, 2, upper_inc=True),
                 asyncpg.Range(None, 3)],
                [(2,), asyncpg.Range(2, None)],
                [(2, None), asyncpg.Range(2, None)],
                [asyncpg.Range(2, None), asyncpg.Range(2, None)],
                [(None, None), asyncpg.Range(None, None)],
                [asyncpg.Range(None, None), asyncpg.Range(None, None)]
            ])
        ]

        for (typname, sample_data) in cases:
            st = await self.con.prepare(
                "SELECT $1::" + typname
            )

            for sample, expected in sample_data:
                with self.subTest(sample=sample, typname=typname):
                    result = await st.fetchval(sample)
                    self.assertEqual(result, expected)

        with self.assertRaisesRegex(
                asyncpg.DataError, 'list, tuple or Range object expected'):
            await self.con.fetch("SELECT $1::int4range", 'aa')

        with self.assertRaisesRegex(
                asyncpg.DataError, 'expected 0, 1 or 2 elements'):
            await self.con.fetch("SELECT $1::int4range", (0, 2, 3))

        # (range a, range b, number of distinct dict keys they produce):
        # checks that Range hashing and equality agree.
        cases = [(asyncpg.Range(0, 1), asyncpg.Range(0, 1), 1),
                 (asyncpg.Range(0, 1), asyncpg.Range(0, 2), 2),
                 (asyncpg.Range(empty=True), asyncpg.Range(0, 2), 2),
                 (asyncpg.Range(empty=True), asyncpg.Range(empty=True), 1),
                 (asyncpg.Range(0, 1, upper_inc=True), asyncpg.Range(0, 1), 2),
                 ]
        for obj_a, obj_b, count in cases:
            dic = {obj_a: 1, obj_b: 2}
            self.assertEqual(len(dic), count)

    async def test_multirange_types(self):
        """Test encoding/decoding of multirange types."""

        if self.server_version < (14, 0):
            self.skipTest("this server does not support multirange types")

        # [input, expected] pairs for each multirange type.
        cases = [
            ('int4multirange', [
                [
                    [],
                    []
                ],
                [
                    [()],
                    []
                ],
                [
                    [asyncpg.Range(empty=True)],
                    []
                ],
                [
                    [asyncpg.Range(0, 9, lower_inc=False, upper_inc=True)],
                    [asyncpg.Range(1, 10)]
                ],
                [
                    [(1, 9), (9, 11)],
                    [asyncpg.Range(1, 12)]
                ],
                [
                    [(1, 9), (20, 30)],
                    [asyncpg.Range(1, 10), asyncpg.Range(20, 31)]
                ],
                [
                    [(None, 2)],
                    [asyncpg.Range(None, 3)],
                ]
            ])
        ]

        for (typname, sample_data) in cases:
            st = await self.con.prepare(
                "SELECT $1::" + typname
            )

            for sample, expected in sample_data:
                with self.subTest(sample=sample, typname=typname):
                    result = await st.fetchval(sample)
                    self.assertEqual(result, expected)

        with self.assertRaisesRegex(
                asyncpg.DataError, 'expected a sequence'):
            await self.con.fetch("SELECT $1::int4multirange", 1)

async def test_extra_codec_alias(self): """Test encoding/decoding of a builtin non-pg_catalog codec.""" await self.con.execute(''' CREATE DOMAIN my_dec_t AS
decimal; CREATE EXTENSION IF NOT EXISTS hstore; CREATE TYPE rec_t AS ( i my_dec_t, h hstore ); ''') try: await self.con.set_builtin_type_codec( 'hstore', codec_name='pg_contrib.hstore') cases = [ {'ham': 'spam', 'nada': None}, {} ] st = await self.con.prepare(''' SELECT $1::hstore AS result ''') for case in cases: res = await st.fetchval(case) self.assertEqual(res, case) res = await self.con.fetchval(''' SELECT $1::hstore AS result ''', (('foo', '2'), ('bar', '3'))) self.assertEqual(res, {'foo': '2', 'bar': '3'}) with self.assertRaisesRegex(asyncpg.DataError, 'null value not allowed'): await self.con.fetchval(''' SELECT $1::hstore AS result ''', {None: '1'}) await self.con.set_builtin_type_codec( 'my_dec_t', codec_name='decimal') res = await self.con.fetchval(''' SELECT $1::my_dec_t AS result ''', 44) self.assertEqual(res, 44) # Both my_dec_t and hstore are decoded in binary res = await self.con.fetchval(''' SELECT ($1::my_dec_t, 'a=>1'::hstore)::rec_t AS result ''', 44) self.assertEqual(res, (44, {'a': '1'})) # Now, declare only the text format for my_dec_t await self.con.reset_type_codec('my_dec_t') await self.con.set_builtin_type_codec( 'my_dec_t', codec_name='decimal', format='text') # This should fail, as there is no binary codec for # my_dec_t and text decoding of composites is not # implemented. 
with self.assertRaises(asyncpg.UnsupportedClientFeatureError): res = await self.con.fetchval(''' SELECT ($1::my_dec_t, 'a=>1'::hstore)::rec_t AS result ''', 44) finally: await self.con.execute(''' DROP TYPE rec_t; DROP EXTENSION hstore; DROP DOMAIN my_dec_t; ''') async def test_custom_codec_text(self): """Test encoding/decoding using a custom codec in text mode.""" await self.con.execute(''' CREATE EXTENSION IF NOT EXISTS hstore ''') def hstore_decoder(data): result = {} items = data.split(',') for item in items: k, _, v = item.partition('=>') result[k.strip('"')] = v.strip('"') return result def hstore_encoder(obj): return ','.join('{}=>{}'.format(k, v) for k, v in obj.items()) try: await self.con.set_type_codec('hstore', encoder=hstore_encoder, decoder=hstore_decoder) st = await self.con.prepare(''' SELECT $1::hstore AS result ''') res = await st.fetchrow({'ham': 'spam'}) res = res['result'] self.assertEqual(res, {'ham': 'spam'}) pt = st.get_parameters() self.assertTrue(isinstance(pt, tuple)) self.assertEqual(len(pt), 1) self.assertEqual(pt[0].name, 'hstore') self.assertEqual(pt[0].kind, 'scalar') self.assertEqual(pt[0].schema, 'public') at = st.get_attributes() self.assertTrue(isinstance(at, tuple)) self.assertEqual(len(at), 1) self.assertEqual(at[0].name, 'result') self.assertEqual(at[0].type, pt[0]) err = 'cannot use custom codec on type public._hstore' with self.assertRaisesRegex(asyncpg.InterfaceError, err): await self.con.set_type_codec('_hstore', encoder=hstore_encoder, decoder=hstore_decoder) finally: await self.con.execute(''' DROP EXTENSION hstore ''') async def test_custom_codec_binary(self): """Test encoding/decoding using a custom codec in binary mode.""" await self.con.execute(''' CREATE EXTENSION IF NOT EXISTS hstore ''') longstruct = struct.Struct('!L') ulong_unpack = lambda b: longstruct.unpack_from(b)[0] ulong_pack = longstruct.pack def hstore_decoder(data): result = {} n = ulong_unpack(data) view = memoryview(data) ptr = 4 for i in range(n): 
klen = ulong_unpack(view[ptr:ptr + 4]) ptr += 4 k = bytes(view[ptr:ptr + klen]).decode() ptr += klen vlen = ulong_unpack(view[ptr:ptr + 4]) ptr += 4 if vlen == -1: v = None else: v = bytes(view[ptr:ptr + vlen]).decode() ptr += vlen result[k] = v return result def hstore_encoder(obj): buffer = bytearray(ulong_pack(len(obj))) for k, v in obj.items(): kenc = k.encode() buffer += ulong_pack(len(kenc)) + kenc if v is None: buffer += b'\xFF\xFF\xFF\xFF' # -1 else: venc = v.encode() buffer += ulong_pack(len(venc)) + venc return buffer try: await self.con.set_type_codec('hstore', encoder=hstore_encoder, decoder=hstore_decoder, format='binary') st = await self.con.prepare(''' SELECT $1::hstore AS result ''') res = await st.fetchrow({'ham': 'spam'}) res = res['result'] self.assertEqual(res, {'ham': 'spam'}) pt = st.get_parameters() self.assertTrue(isinstance(pt, tuple)) self.assertEqual(len(pt), 1) self.assertEqual(pt[0].name, 'hstore') self.assertEqual(pt[0].kind, 'scalar') self.assertEqual(pt[0].schema, 'public') at = st.get_attributes() self.assertTrue(isinstance(at, tuple)) self.assertEqual(len(at), 1) self.assertEqual(at[0].name, 'result') self.assertEqual(at[0].type, pt[0]) finally: await self.con.execute(''' DROP EXTENSION hstore ''') async def test_custom_codec_on_domain(self): """Test encoding/decoding using a custom codec on a domain.""" await self.con.execute(''' CREATE DOMAIN custom_codec_t AS int ''') try: with self.assertRaisesRegex( asyncpg.UnsupportedClientFeatureError, 'custom codecs on domain types are not supported' ): await self.con.set_type_codec( 'custom_codec_t', encoder=lambda v: str(v), decoder=lambda v: int(v)) finally: await self.con.execute('DROP DOMAIN custom_codec_t') async def test_custom_codec_on_stdsql_types(self): types = [ 'smallint', 'int', 'integer', 'bigint', 'decimal', 'real', 'double precision', 'timestamp with timezone', 'time with timezone', 'timestamp without timezone', 'time without timezone', 'char', 'character', 'character 
async def test_custom_codec_on_enum(self):
    """Test encoding/decoding using a custom codec on an enum.

    The decoder tags every value with an ``'enum: '`` prefix; the
    encoder removes that prefix again before the value is sent.
    """
    await self.con.execute(
        ''' CREATE TYPE custom_codec_t AS ENUM ('foo', 'bar', 'baz') ''')

    prefix = 'enum: '

    def encoder(v):
        # Strip the prefix as a whole.  str.lstrip() treats its argument
        # as a set of characters and would also eat leading letters of
        # the label itself (e.g. a label starting with 'm' or 'n').
        text = str(v)
        if text.startswith(prefix):
            return text[len(prefix):]
        return text

    def decoder(v):
        return prefix + str(v)

    try:
        await self.con.set_type_codec(
            'custom_codec_t',
            encoder=encoder,
            decoder=decoder)

        v = await self.con.fetchval('SELECT $1::custom_codec_t', 'foo')
        self.assertEqual(v, 'enum: foo')
    finally:
        await self.con.execute('DROP TYPE custom_codec_t')
async def test_custom_codec_override_text(self):
    """Test overriding core codecs.

    Overrides the ``json`` codec in text format and checks scalars,
    arrays and a domain over ``json``; then overrides ``uuid`` with
    pass-through functions.  Runs on a dedicated connection.
    """
    import json

    conn = await self.connect()
    try:
        def _encoder(value):
            return json.dumps(value)

        def _decoder(value):
            return json.loads(value)

        await conn.set_type_codec(
            'json', encoder=_encoder, decoder=_decoder,
            schema='pg_catalog', format='text'
        )

        data = {'foo': 'bar', 'spam': 1}
        res = await conn.fetchval('SELECT $1::json', data)
        self.assertEqual(data, res)

        res = await conn.fetchval('SELECT $1::json[]', [data])
        self.assertEqual([data], res)

        await conn.execute('CREATE DOMAIN my_json AS json')
        res = await conn.fetchval('SELECT $1::my_json', data)
        self.assertEqual(data, res)

        def _encoder(value):
            return value

        def _decoder(value):
            return value

        await conn.set_type_codec(
            'uuid', encoder=_encoder, decoder=_decoder,
            schema='pg_catalog', format='text'
        )

        data = '14058ad9-0118-4b7e-ac15-01bc13e2ccd1'
        res = await conn.fetchval('SELECT $1::uuid', data)
        self.assertEqual(res, data)
    finally:
        # Close the connection even when the cleanup statement itself
        # fails, so a failing test does not leak it.
        try:
            await conn.execute('DROP DOMAIN IF EXISTS my_json')
        finally:
            await conn.close()
async def test_custom_codec_composite_tuple(self):
    """Round-trip a Python ``complex`` through a tuple-format codec
    registered on a composite type."""
    await self.con.execute(
        ''' CREATE TYPE mycomplex AS (r float, i float); ''')

    def to_pair(x):
        return (x.real, x.imag)

    def from_pair(t):
        return complex(t[0], t[1])

    try:
        await self.con.set_type_codec(
            'mycomplex', encoder=to_pair, decoder=from_pair, format='tuple')

        sent = complex('1+2j')
        received = await self.con.fetchval('SELECT $1::mycomplex', sent)
        self.assertEqual(sent, received)
    finally:
        await self.con.execute(''' DROP TYPE mycomplex; ''')
async def test_composites_in_arrays(self):
    """Insert and read back an array of composite values."""
    create_sql = (
        " CREATE TYPE t AS (a text, b int);"
        " CREATE TABLE tab (d t[]); "
    )
    drop_sql = " DROP TABLE tab; DROP TYPE t; "

    await self.con.execute(create_sql)
    try:
        await self.con.execute(
            'INSERT INTO tab (d) VALUES ($1)', [('a', 1)])

        fetched = await self.con.fetchval(''' SELECT d FROM tab ''')
        self.assertEqual(fetched, [('a', 1)])
    finally:
        await self.con.execute(drop_sql)
async def test_enum(self):
    """Fetch rows ordered by an enum column, repeating the query ten
    times and checking the result each time."""
    setup_sql = (
        " CREATE TYPE enum_t AS ENUM ('abc', 'def', 'ghi');"
        " CREATE TABLE tab ( a text, b enum_t );"
        " INSERT INTO tab (a, b) VALUES ('foo', 'abc');"
        " INSERT INTO tab (a, b) VALUES ('bar', 'def'); "
    )
    await self.con.execute(setup_sql)

    try:
        attempt = 0
        while attempt < 10:
            rows = await self.con.fetch(
                ''' SELECT a, b FROM tab ORDER BY b ''')
            self.assertEqual(rows, [('foo', 'abc'), ('bar', 'def')])
            attempt += 1
    finally:
        await self.con.execute(''' DROP TABLE tab; DROP TYPE enum_t; ''')
result = await self.con.fetchval(''' SELECT $1::citext ''', 'citext') self.assertEqual(result, 'citext') # Check that domain fallback works. result = await self.con.fetchval(''' SELECT $1::citext_dom ''', 'citext') self.assertEqual(result, 'citext') # Check that array fallback works. cases = [ ['a', 'b'], [None, 'b'], [], [' a', ' b'], ['"a', r'\""'], [['"a', r'\""'], [',', '",']], ] for case in cases: result = await self.con.fetchval(''' SELECT $1::citext[] ''', case) self.assertEqual(result, case) # Text encoding of ranges and composite types # is not supported yet. with self.assertRaisesRegex( asyncpg.UnsupportedClientFeatureError, 'text encoding of range types is not supported'): await self.con.fetchval(''' SELECT $1::citext_range ''', ['a', 'z']) with self.assertRaisesRegex( asyncpg.UnsupportedClientFeatureError, 'text encoding of composite types is not supported'): await self.con.fetchval(''' SELECT $1::citext_comp ''', ('a',)) # Check that setting a custom codec clears the codec # cache properly and that subsequent queries work # as expected. 
async def test_enum_in_array(self):
    """Check enum arrays, both passed as an argument and built
    server-side from a scalar argument."""
    await self.con.execute(
        ''' CREATE TYPE enum_t AS ENUM ('abc', 'def', 'ghi'); ''')

    # (query, argument) pairs; both must yield a one-column row
    # holding ['abc'].
    queries = (
        ('''SELECT $1::enum_t[];''', ['abc']),
        ('''SELECT ARRAY[$1::enum_t];''', 'abc'),
    )

    try:
        for sql, argument in queries:
            row = await self.con.fetchrow(sql, argument)
            self.assertEqual(row, (['abc'],))
    finally:
        await self.con.execute(''' DROP TYPE enum_t; ''')
@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
class TestCodecsLargeOIDs(tb.ConnectedTestCase):
    """Codec tests against a temporary cluster whose OID counter has
    been reset to ``LARGE_OID``."""

    # Equal to 2**31.
    LARGE_OID = 2147483648

    @classmethod
    def setup_cluster(cls):
        """Create a temporary cluster, reset its WAL so that OIDs start
        at ``LARGE_OID``, then start it."""
        cls.cluster = cls.new_cluster(pg_cluster.TempCluster)
        cls.cluster.reset_wal(oid=cls.LARGE_OID)
        cls.start_cluster(cls.cluster)

    async def test_custom_codec_large_oid(self):
        """A domain created on the fresh cluster gets the expected
        large OID and values of that type still round-trip."""
        await self.con.execute('CREATE DOMAIN test_domain_t AS int')
        try:
            actual_oid = await self.con.fetchval(
                ''' SELECT oid FROM pg_type WHERE typname = 'test_domain_t' '''
            )

            # PostgreSQL 11 automatically creates a domain array type
            # _before_ the domain type, so the expected OID is
            # off by one.
            offset = 1 if self.server_version >= (11, 0) else 0
            self.assertEqual(actual_oid, self.LARGE_OID + offset)

            # Test that introspection handles large OIDs
            value = await self.con.fetchval('SELECT $1::test_domain_t', 10)
            self.assertEqual(value, 10)
        finally:
            await self.con.execute('DROP DOMAIN test_domain_t')
@contextlib.contextmanager
def mock_no_home_dir():
    """Patch ``pathlib.Path.home`` so that calling it raises
    ``RuntimeError`` for the duration of the ``with`` block."""
    failing_home = unittest.mock.Mock(side_effect=RuntimeError)
    patcher = unittest.mock.patch('pathlib.Path.home', failing_home)
    with patcher:
        yield
def setUp(self):
    """Create the roles listed in ``USERS`` and matching pg_hba entries.

    Skips the test on unmanaged clusters.  Users with the
    ``scram-sha-256`` method are ignored on servers older than
    version 10.
    """
    super().setUp()

    if not self.cluster.is_managed():
        self.skipTest('unmanaged cluster')

    self.cluster.reset_hba()

    statements = []
    for username, method, password in self.USERS:
        uses_scram = method == 'scram-sha-256'
        if uses_scram and self.server_version.major < 10:
            continue

        # if this is a SCRAM password, we need to set the encryption method
        # to "scram-sha-256" in order to properly hash the password
        if uses_scram:
            statements.append("SET password_encryption = 'scram-sha-256';")

        password_clause = f' PASSWORD E{(password or "")!r}'
        statements.append(
            'CREATE ROLE "{}" WITH LOGIN{};'.format(username, password_clause)
        )

        # to be courteous to the MD5 test, revert back to MD5 after the
        # scram-sha-256 password is set
        if uses_scram:
            statements.append("SET password_encryption = 'md5';")

        if _system != 'Windows' and method != 'gss':
            self.cluster.add_hba_entry(
                type='local',
                database='postgres', user=username,
                auth_method=method)

        # One host entry for the IPv4 loopback network, one for IPv6.
        for network in ('127.0.0.0/24', '::1/128'):
            self.cluster.add_hba_entry(
                type='host', address=ipaddress.ip_network(network),
                database='postgres', user=username,
                auth_method=method)

    # Put hba changes into effect
    self.cluster.reload()

    self.loop.run_until_complete(self.con.execute('\n'.join(statements)))
async def _try_connect(self, **kwargs):
    """Connect with ``kwargs``, tolerating early disconnects on Windows.

    On Windows up to three attempts that fail with
    ``ConnectionDoesNotExistError`` are swallowed before one final
    attempt whose outcome is returned or raised unchanged.  Elsewhere a
    single attempt is made.
    """
    # On Windows the server sometimes just closes
    # the connection sooner than we receive the
    # actual error.
    tolerated_attempts = 3 if _system == 'Windows' else 0
    while tolerated_attempts > 0:
        tolerated_attempts -= 1
        try:
            return await self.connect(**kwargs)
        except asyncpg.ConnectionDoesNotExistError:
            continue

    return await self.connect(**kwargs)
async def test_auth_password_cleartext_callable_coroutine(self):
    """Cleartext auth where ``password`` is a coroutine function."""
    async def good_password():
        return CORRECT_PASSWORD

    async def bad_password():
        return 'wrongpassword'

    connection = await self.connect(
        user='password_user', password=good_password)
    await connection.close()

    expected_error = (
        'password authentication failed for user "password_user"'
    )
    with self.assertRaisesRegex(
            asyncpg.InvalidPasswordError, expected_error):
        await self._try_connect(
            user='password_user', password=bad_password)
@unittest.mock.patch('hashlib.md5', side_effect=ValueError("no md5"))
async def test_auth_md5_unsupported(self, _):
    """With ``hashlib.md5`` patched to raise, connecting as the MD5
    user must fail with ``InternalClientError`` mentioning the cause."""
    expect_failure = self.assertRaisesRegex(
        exceptions.InternalClientError, ".*no md5.*")
    with expect_failure:
        await self.connect(user='md5_user', password=CORRECT_PASSWORD)
@unittest.skipIf(_system != 'Windows', 'SSPI is only available on Windows')
class TestSspiAuthentication(BaseTestAuthentication):
    """SSPI authentication tests (skipped unless running on Windows)."""

    @classmethod
    def setUpClass(cls):
        """Derive the ``login@hostname`` user name and register it,
        plus a ``wrong-`` prefixed variant, as SSPI users."""
        login = os.getlogin()
        hostname = socket.gethostname()
        cls.username = f'{login}@{hostname}'
        cls.USERS = [
            (name, 'sspi', None)
            for name in (cls.username, f'wrong-{cls.username}')
        ]
        super().setUpClass()

    async def test_auth_sspi(self):
        """The real user connects; the ``wrong-`` user is rejected."""
        connection = await self.connect(user=self.username)
        await connection.close()

        # Credentials mismatch.
        mismatched_user = f'wrong-{self.username}'
        with self.assertRaisesRegex(
            exceptions.InvalidAuthorizationSpecificationError,
            'SSPI authentication failed for user'
        ):
            await self.connect(user=mismatched_user)
'postgres://user3:123123@localhost/abcdef', 'host': 'host2', 'port': '456', 'user': 'user2', 'password': 'passw2', 'database': 'db2', 'ssl': False, 'result': ([('host2', 456)], { 'user': 'user2', 'password': 'passw2', 'database': 'db2', 'sslmode': SSLMode.disable, 'ssl': False, 'target_session_attrs': 'any'}) }, { 'name': 'params_ssl_negotiation_dsn', 'env': { 'PGSSLNEGOTIATION': 'postgres' }, 'dsn': 'postgres://u:p@localhost/d?sslnegotiation=direct', 'result': ([('localhost', 5432)], { 'user': 'u', 'password': 'p', 'database': 'd', 'ssl_negotiation': 'direct', 'target_session_attrs': 'any', }) }, { 'name': 'params_ssl_negotiation_env', 'env': { 'PGSSLNEGOTIATION': 'direct' }, 'dsn': 'postgres://u:p@localhost/d', 'result': ([('localhost', 5432)], { 'user': 'u', 'password': 'p', 'database': 'd', 'ssl_negotiation': 'direct', 'target_session_attrs': 'any', }) }, { 'name': 'params_ssl_negotiation_params', 'env': { 'PGSSLNEGOTIATION': 'direct' }, 'dsn': 'postgres://u:p@localhost/d', 'direct_tls': False, 'result': ([('localhost', 5432)], { 'user': 'u', 'password': 'p', 'database': 'd', 'ssl_negotiation': 'postgres', 'target_session_attrs': 'any', }) }, { 'name': 'dsn_overrides_env_partially_ssl_prefer', 'env': { 'PGUSER': 'user', 'PGDATABASE': 'testdb', 'PGPASSWORD': 'passw', 'PGHOST': 'host', 'PGPORT': '123', 'PGSSLMODE': 'prefer' }, 'dsn': 'postgres://user3:123123@localhost:5555/abcdef', 'result': ([('localhost', 5555)], { 'user': 'user3', 'password': '123123', 'database': 'abcdef', 'ssl': True, 'sslmode': SSLMode.prefer, 'target_session_attrs': 'any'}) }, { 'name': 'dsn_only', 'dsn': 'postgres://user3:123123@localhost:5555/abcdef', 'result': ([('localhost', 5555)], { 'user': 'user3', 'password': '123123', 'database': 'abcdef', 'target_session_attrs': 'any'}) }, { 'name': 'dsn_only_multi_host', 'dsn': 'postgresql://user@host1,host2/db', 'result': ([('host1', 5432), ('host2', 5432)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'any', }) }, { 'name': 
'dsn_only_multi_host_and_port', 'dsn': 'postgresql://user@host1:1111,host2:2222/db', 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'any', }) }, { 'name': 'target_session_attrs', 'dsn': 'postgresql://user@host1:1111,host2:2222/db' '?target_session_attrs=read-only', 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'read-only', }) }, { 'name': 'target_session_attrs_2', 'dsn': 'postgresql://user@host1:1111,host2:2222/db' '?target_session_attrs=read-only', 'target_session_attrs': 'read-write', 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'read-write', }) }, { 'name': 'target_session_attrs_3', 'dsn': 'postgresql://user@host1:1111,host2:2222/db', 'env': { 'PGTARGETSESSIONATTRS': 'read-only', }, 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'read-only', }) }, { 'name': 'krbsrvname', 'dsn': 'postgresql://user@host/db?krbsrvname=srv_qs', 'env': { 'PGKRBSRVNAME': 'srv_env', }, 'result': ([('host', 5432)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'any', 'krbsrvname': 'srv_qs', }) }, { 'name': 'krbsrvname_2', 'dsn': 'postgresql://user@host/db?krbsrvname=srv_qs', 'krbsrvname': 'srv_kws', 'env': { 'PGKRBSRVNAME': 'srv_env', }, 'result': ([('host', 5432)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'any', 'krbsrvname': 'srv_kws', }) }, { 'name': 'krbsrvname_3', 'dsn': 'postgresql://user@host/db', 'env': { 'PGKRBSRVNAME': 'srv_env', }, 'result': ([('host', 5432)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'any', 'krbsrvname': 'srv_env', }) }, { 'name': 'gsslib', 'dsn': f'postgresql://user@host/db?gsslib={OTHER_GSSLIB}', 'env': { 'PGGSSLIB': 'ignored', }, 'result': ([('host', 5432)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'any', 'gsslib': OTHER_GSSLIB, }) }, { 'name': 
'gsslib_2', 'dsn': 'postgresql://user@host/db?gsslib=ignored', 'gsslib': OTHER_GSSLIB, 'env': { 'PGGSSLIB': 'ignored', }, 'result': ([('host', 5432)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'any', 'gsslib': OTHER_GSSLIB, }) }, { 'name': 'gsslib_3', 'dsn': 'postgresql://user@host/db', 'env': { 'PGGSSLIB': OTHER_GSSLIB, }, 'result': ([('host', 5432)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'any', 'gsslib': OTHER_GSSLIB, }) }, { 'name': 'gsslib_4', 'dsn': 'postgresql://user@host/db', 'result': ([('host', 5432)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'any', 'gsslib': DEFAULT_GSSLIB, }) }, { 'name': 'gsslib_5', 'dsn': 'postgresql://user@host/db?gsslib=invalid', 'error': ( exceptions.ClientConfigurationError, "gsslib parameter must be either 'gssapi' or 'sspi'" ), }, # broken by https://github.com/python/cpython/pull/129418 # { # 'name': 'dsn_ipv6_multi_host', # 'dsn': 'postgresql://user@[2001:db8::1234%25eth0],[::1]/db', # 'result': ([('2001:db8::1234%eth0', 5432), ('::1', 5432)], { # 'database': 'db', # 'user': 'user', # 'target_session_attrs': 'any', # }) # }, # { # 'name': 'dsn_ipv6_multi_host_port', # 'dsn': 'postgresql://user@[2001:db8::1234]:1111,[::1]:2222/db', # 'result': ([('2001:db8::1234', 1111), ('::1', 2222)], { # 'database': 'db', # 'user': 'user', # 'target_session_attrs': 'any', # }) # }, { 'name': 'dsn_ipv6_multi_host_query_part', 'dsn': 'postgresql:///db?user=user&host=[2001:db8::1234],[::1]', 'result': ([('2001:db8::1234', 5432), ('::1', 5432)], { 'database': 'db', 'user': 'user', 'target_session_attrs': 'any', }) }, { 'name': 'dsn_combines_env_multi_host', 'env': { 'PGHOST': 'host1:1111,host2:2222', 'PGUSER': 'foo', }, 'dsn': 'postgresql:///db', 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'foo', 'target_session_attrs': 'any', }) }, { 'name': 'dsn_multi_host_combines_env', 'env': { 'PGUSER': 'foo', }, 'dsn': 'postgresql:///db?host=host1:1111,host2:2222', 
'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', 'user': 'foo', 'target_session_attrs': 'any', }) }, { 'name': 'params_multi_host_dsn_env_mix', 'env': { 'PGUSER': 'foo', }, 'dsn': 'postgresql:///db', 'host': ['host1', 'host2'], 'result': ([('host1', 5432), ('host2', 5432)], { 'database': 'db', 'user': 'foo', 'target_session_attrs': 'any', }) }, { 'name': 'params_multi_host_dsn_env_mix_tuple', 'env': { 'PGUSER': 'foo', }, 'dsn': 'postgresql:///db', 'host': ('host1', 'host2'), 'result': ([('host1', 5432), ('host2', 5432)], { 'database': 'db', 'user': 'foo', 'target_session_attrs': 'any', }) }, { 'name': 'params_combine_dsn_settings_override_and_ssl', 'dsn': 'postgresql://user3:123123@localhost:5555/' 'abcdef?param=sss¶m=123&host=testhost&user=testuser' '&port=2222&database=testdb&sslmode=require', 'host': '127.0.0.1', 'port': '888', 'user': 'me', 'password': 'ask', 'database': 'db', 'result': ([('127.0.0.1', 888)], { 'server_settings': {'param': '123'}, 'user': 'me', 'password': 'ask', 'database': 'db', 'ssl': True, 'sslmode': SSLMode.require, 'target_session_attrs': 'any'}) }, { 'name': 'params_settings_and_ssl_override_dsn', 'dsn': 'postgresql://user3:123123@localhost:5555/' 'abcdef?param=sss¶m=123&host=testhost&user=testuser' '&port=2222&database=testdb&sslmode=disable', 'host': '127.0.0.1', 'port': '888', 'user': 'me', 'password': 'ask', 'database': 'db', 'server_settings': {'aa': 'bb'}, 'ssl': True, 'result': ([('127.0.0.1', 888)], { 'server_settings': {'aa': 'bb', 'param': '123'}, 'user': 'me', 'password': 'ask', 'database': 'db', 'sslmode': SSLMode.verify_full, 'ssl': True, 'target_session_attrs': 'any'}) }, { 'name': 'dsn_only_unix', 'dsn': 'postgresql:///dbname?host=/unix_sock/test&user=spam', 'result': ([os.path.join('/unix_sock/test', '.s.PGSQL.5432')], { 'user': 'spam', 'database': 'dbname', 'target_session_attrs': 'any'}) }, { 'name': 'dsn_only_quoted', 'dsn': 'postgresql://us%40r:p%40ss@h%40st1,h%40st2:543%33/d%62', 'result': ( 
[('h@st1', 5432), ('h@st2', 5433)], { 'user': 'us@r', 'password': 'p@ss', 'database': 'db', 'target_session_attrs': 'any', } ) }, { 'name': 'dsn_only_unquoted_host', 'dsn': 'postgresql://user:p@ss@host/db', 'result': ( [('ss@host', 5432)], { 'user': 'user', 'password': 'p', 'database': 'db', 'target_session_attrs': 'any', } ) }, { 'name': 'dsn_only_quoted_params', 'dsn': 'postgresql:///d%62?user=us%40r&host=h%40st&port=543%33', 'result': ( [('h@st', 5433)], { 'user': 'us@r', 'database': 'db', 'target_session_attrs': 'any', } ) }, { 'name': 'dsn_only_illegal_protocol', 'dsn': 'pq:///dbname?host=/unix_sock/test&user=spam', 'error': (ValueError, 'invalid DSN') }, { 'name': 'dsn_params_ports_mismatch_dsn_multi_hosts', 'dsn': 'postgresql://host1,host2,host3/db', 'port': [111, 222], 'error': ( exceptions.InterfaceError, 'could not match 2 port numbers to 3 hosts' ) }, { 'name': 'dsn_only_quoted_unix_host_port_in_params', 'dsn': 'postgres://user@?port=56226&host=%2Ftmp', 'result': ( [os.path.join('/tmp', '.s.PGSQL.56226')], { 'user': 'user', 'database': 'user', 'sslmode': SSLMode.disable, 'ssl': None, 'target_session_attrs': 'any', } ) }, { 'name': 'dsn_only_cloudsql', 'dsn': 'postgres:///db?host=/cloudsql/' 'project:region:instance-name&user=spam', 'result': ( [os.path.join( '/cloudsql/project:region:instance-name', '.s.PGSQL.5432' )], { 'user': 'spam', 'database': 'db', 'target_session_attrs': 'any', } ) }, { 'name': 'dsn_only_cloudsql_unix_and_tcp', 'dsn': 'postgres:///db?host=127.0.0.1:5432,/cloudsql/' 'project:region:instance-name,localhost:5433&user=spam', 'result': ( [ ('127.0.0.1', 5432), os.path.join( '/cloudsql/project:region:instance-name', '.s.PGSQL.5432' ), ('localhost', 5433) ], { 'user': 'spam', 'database': 'db', 'ssl': True, 'sslmode': SSLMode.prefer, 'target_session_attrs': 'any', } ) }, { 'name': 'multi_host_single_port', 'dsn': 'postgres:///postgres?host=127.0.0.1,127.0.0.2&port=5432' '&user=postgres', 'result': ( [ ('127.0.0.1', 5432), ('127.0.0.2', 
@contextlib.contextmanager
def environ(self, **kwargs):
    """Temporarily override process environment variables.

    Every keyword names a variable.  A string value sets the variable for
    the duration of the ``with`` body; ``None`` removes it.  On exit all
    named variables are dropped and the values that were present on entry
    are put back.
    """
    # Remember only the variables that actually exist right now.
    saved = {name: os.environ[name] for name in kwargs if name in os.environ}

    for name, value in kwargs.items():
        if value is not None:
            os.environ[name] = value
        else:
            os.environ.pop(name, None)

    try:
        yield
    finally:
        # Drop whatever the body (or we) left behind, then restore.
        for name in kwargs:
            os.environ.pop(name, None)
        os.environ.update(saved)
def test_test_connect_params_environ(self):
    # Sanity-check the environ() helper itself: a brand-new variable, an
    # overridden one and a removed one must all be visible inside the
    # block and be put back exactly as they were afterwards.
    fresh = 'AAAAAAAAAA123'
    overridden = 'AAAAAAAAAA456'
    removed = 'AAAAAAAAAA789'
    names = (fresh, overridden, removed)

    for name in names:
        self.assertNotIn(name, os.environ)

    try:
        os.environ[overridden] = '123'
        os.environ[removed] = '123'

        overrides = {fresh: '1', overridden: '2', removed: None}
        with self.environ(**overrides):
            self.assertEqual(os.environ[fresh], '1')
            self.assertEqual(os.environ[overridden], '2')
            self.assertNotIn(removed, os.environ)

        self.assertNotIn(fresh, os.environ)
        self.assertEqual(os.environ[overridden], '123')
        self.assertEqual(os.environ[removed], '123')
    finally:
        # Leave no trace in the process environment, pass or fail.
        for name in names:
            os.environ.pop(name, None)
def test_connect_connection_service_file(self):
    # Connection parameters may come from a connection service file that
    # is selected with ``service=`` in the DSN and located through the
    # PGSERVICEFILE environment variable.
    service_file = tempfile.NamedTemporaryFile('w+t', delete=False)
    try:
        # The ``with`` closes the handle even when the write fails, so the
        # unlink in the ``finally`` below always sees a closed file and the
        # temporary file is never left behind.
        with service_file:
            service_file.write(textwrap.dedent('''
                [test_service_dbname]
                port=5433
                host=somehost
                dbname=test_dbname
                user=admin
                password=test_password
                target_session_attrs=primary
                krbsrvname=fakekrbsrvname
                gsslib=sspi

                [test_service_database]
                port=5433
                host=somehost
                database=test_dbname
                user=admin
                password=test_password
                target_session_attrs=primary
                krbsrvname=fakekrbsrvname
                gsslib=sspi
            '''))
        os.chmod(service_file.name, stat.S_IWUSR | stat.S_IRUSR)

        def expected(database='test_dbname'):
            # A fresh structure per test case, so no case can be affected
            # by another one mutating a shared expectation.
            return (
                [('somehost', 5433)],
                {
                    'user': 'admin',
                    'password': 'test_password',
                    'database': database,
                    'target_session_attrs': 'primary',
                    'krbsrvname': 'fakekrbsrvname',
                    'gsslib': 'sspi',
                }
            )

        # Test connection service file with dbname
        self.run_testcase({
            'dsn': 'postgresql://?service=test_service_dbname',
            'env': {
                'PGSERVICEFILE': service_file.name
            },
            'result': expected(),
        })

        # Test connection service file with database
        self.run_testcase({
            'dsn': 'postgresql://?service=test_service_database',
            'env': {
                'PGSERVICEFILE': service_file.name
            },
            'result': expected(),
        })

        # Test that envvars are overridden by service file
        self.run_testcase({
            'dsn': 'postgresql://?service=test_service_dbname',
            'env': {
                'PGUSER': 'user',
                'PGSERVICEFILE': service_file.name
            },
            'result': expected(),
        })

        # Test that dsn params overwrite service file
        self.run_testcase({
            'dsn': 'postgresql://?service={}&dbname={}'.format(
                "test_service_dbname", "test_dbname_dsn"
            ),
            'env': {
                'PGSERVICEFILE': service_file.name
            },
            'result': expected(database='test_dbname_dsn'),
        })
    finally:
        os.unlink(service_file.name)
@unittest.skipIf(_system == 'Windows', 'no mode checking on Windows')
def test_connect_pgpass_badness_mode(self):
    # A password file that is readable/writable by the group must produce
    # a warning, and the parsed parameters must carry no password.
    group_accessible = (
        stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP)
    warning_pattern = 'password file .* has group or world access'

    with tempfile.NamedTemporaryFile('w+t') as passfile:
        os.chmod(passfile.name, group_accessible)

        testcase = {
            'host': 'abc',
            'user': 'user',
            'database': 'db',
            'passfile': passfile.name,
            'result': (
                [('abc', 5432)],
                {
                    'user': 'user',
                    'database': 'db',
                    'target_session_attrs': 'any',
                }
            )
        }

        with self.assertWarnsRegex(UserWarning, warning_pattern):
            self.run_testcase(testcase)
async def test_connection_use_after_close(self):
    # Once the connection is closed, each of the calls below must be
    # rejected with InterfaceError('connection is closed').
    con = self.con
    await con.close()

    def rejects_closed():
        return self.assertRaisesRegex(
            asyncpg.InterfaceError, 'connection is closed')

    with rejects_closed():
        await con.add_listener('aaa', lambda: None)

    with rejects_closed():
        con.transaction()

    with rejects_closed():
        await con.executemany('SELECT 1', [])

    with rejects_closed():
        await con.set_type_codec('aaa', encoder=None, decoder=None)

    with rejects_closed():
        await con.set_builtin_type_codec('aaa', codec_name='aaa')

    # Methods that all take a single query string.
    query_methods = (
        'execute', 'fetch', 'fetchval', 'fetchrow', 'prepare', 'cursor')
    for method_name in query_methods:
        with rejects_closed():
            await getattr(con, method_name)('SELECT 1')

    with rejects_closed():
        await con.reset()
class BaseTestSSLConnection(tb.ConnectedTestCase):
    """Shared fixture for tests that run against an SSL-enabled cluster.

    The class starts a temporary cluster with SSL switched on and, around
    every test, creates the ``ssl_user`` role and installs the pg_hba
    entries supplied by the subclass through ``_add_hba_entry``.
    """

    @classmethod
    def get_server_settings(cls):
        # Extend the inherited server settings with the SSL configuration.
        conf = super().get_server_settings()
        conf.update({
            'ssl': 'on',
            'ssl_cert_file': SSL_CERT_FILE,
            'ssl_key_file': SSL_KEY_FILE,
            'ssl_ca_file': CLIENT_CA_CERT_FILE,
        })
        # The protocol version bounds are only set for PostgreSQL 12+.
        if cls.cluster.get_pg_version() >= (12, 0):
            conf['ssl_min_protocol_version'] = 'TLSv1.2'
            conf['ssl_max_protocol_version'] = 'TLSv1.2'

        return conf

    @classmethod
    def setup_cluster(cls):
        # get_server_settings() reads cls.cluster, so assign it first.
        cls.cluster = cls.new_cluster(pg_cluster.TempCluster)
        cls.start_cluster(
            cls.cluster, server_settings=cls.get_server_settings())

    def setUp(self):
        super().setUp()

        self.cluster.reset_hba()
        self._add_hba_entry()

        # Put hba changes into effect
        self.cluster.reload()

        create_script = '\n'.join([
            'CREATE ROLE ssl_user WITH LOGIN;',
            'GRANT ALL ON SCHEMA public TO ssl_user;',
        ])
        self.loop.run_until_complete(self.con.execute(create_script))

    def tearDown(self):
        try:
            # Reset cluster's pg_hba.conf since we've meddled with it
            self.cluster.trust_local_connections()

            drop_script = '\n'.join([
                'REVOKE ALL ON SCHEMA public FROM ssl_user;',
                'DROP ROLE ssl_user;',
            ])
            self.loop.run_until_complete(self.con.execute(drop_script))
        finally:
            # Always run the base-class teardown, even when the cleanup
            # above fails, so that whatever it releases is not leaked and
            # does not bleed into the next test.
            super().tearDown()

    def _add_hba_entry(self):
        # Subclasses install the pg_hba.conf entries they want to test.
        raise NotImplementedError()
async def test_ssl_connection_default_context(self):
    # Connecting with ssl=True must end in a certificate verification
    # failure against the test server.
    def swallow_loop_errors(*args):
        return None

    loop = self.loop
    # XXX: uvloop artifact
    previous_handler = loop.get_exception_handler()
    try:
        loop.set_exception_handler(swallow_loop_errors)
        expect_verify_failure = self.assertRaisesRegex(
            ssl.SSLError, 'verify failed')
        with expect_verify_failure:
            await self.connect(
                host='localhost',
                user='ssl_user',
                ssl=True)
    finally:
        loop.set_exception_handler(previous_handler)
async def _test_works(self, **conn_args):
    """Open a connection with *conn_args* and run a trivial query on it.

    The connection is closed whether or not the query check passes.
    """
    connection = await self.connect(**conn_args)
    try:
        answer = await connection.fetchval('SELECT 42')
        self.assertEqual(answer, 42)
    finally:
        await connection.close()
async def test_ssl_connection_client_auth_dsn(self):
    # Client-certificate settings passed purely through DSN query
    # parameters, first with a plain key and then with a
    # password-protected one.
    def build_dsn(query):
        return ('postgres://ssl_user@localhost/postgres?'
                + urllib.parse.urlencode(query))

    plain_key_query = {
        'sslrootcert': SSL_CA_CERT_FILE,
        'sslcert': CLIENT_SSL_CERT_FILE,
        'sslkey': CLIENT_SSL_KEY_FILE,
        'sslmode': 'verify-full',
    }
    await self._test_works(dsn=build_dsn(plain_key_query))

    # Same parameters, but the key now needs ``sslpassword``.
    protected_key_query = dict(
        plain_key_query,
        sslkey=CLIENT_SSL_PROTECTED_KEY_FILE,
        sslpassword='secRet',
    )
    await self._test_works(dsn=build_dsn(protected_key_query))
async def test_nossl_connection_pool(self):
    # 100 workers share a 5-10 connection pool opened with ssl='prefer';
    # each checks that its connection is not using SSL and that it is
    # still usable after a query timeout.
    pool_options = dict(
        host='localhost',
        user='ssl_user',
        database='postgres',
        min_size=5,
        max_size=10,
        ssl='prefer',
    )
    pool = await self.create_pool(**pool_options)

    async def exercise_one_connection():
        async with pool.acquire() as con:
            self.assertFalse(con._protocol.is_ssl)

            before = await con.fetchval('SELECT 42')
            self.assertEqual(before, 42)

            with self.assertRaises(asyncio.TimeoutError):
                await con.execute('SELECT pg_sleep(5)', timeout=0.5)

            after = await con.fetchval('SELECT 43')
            self.assertEqual(after, 43)

    await asyncio.gather(*(exercise_one_connection() for _ in range(100)))
    await pool.close()
async def _run_connection_test(
    self, connect, target_attribute, expected_port
):
    """Connect with *target_attribute* and check which server answered.

    *connect* is awaited with ``target_session_attrs=target_attribute``.
    The peer address reported for the resulting connection must end with
    *expected_port*.
    """
    conn = await connect(target_session_attrs=target_attribute)
    try:
        self.assertTrue(
            _get_connected_host(conn).endswith(expected_port))
    finally:
        # Close the connection even when the assertion above fails.
        await conn.close()
def _get_connected_host(con):
    """Return the peer address of *con* as reported by its transport.

    When the transport's ``peername`` is a tuple, its host and port are
    concatenated without a separator, so callers can match the port with
    ``str.endswith``.  Any other value (including ``None``) is returned
    unchanged.
    """
    peername = con._transport.get_extra_info('peername')
    if isinstance(peername, tuple):
        # IPv6 peer names are 4-tuples (host, port, flowinfo, scope_id).
        # Only host and port are used so that a non-zero flowinfo or
        # scope id can never end up after the port.
        peername = "".join(str(part) for part in peername[:2] if part)
    return peername
async def test_copy_from_table_basics(self):
    """Copy ``copytab`` into a buffer with defaults, then as CSV.

    The table holds five generated rows plus one row containing NULLs.
    The second copy runs with ``search_path`` set to ``none`` and passes
    the schema, column list and CSV formatting options explicitly.
    """
    await self.con.execute('''
        CREATE TABLE copytab(a text, "b~" text, i int);
        INSERT INTO copytab (a, "b~", i) (
            SELECT 'a' || i::text, 'b' || i::text, i
            FROM generate_series(1, 5) AS i
        );
        INSERT INTO copytab (a, "b~", i) VALUES('*', NULL, NULL);
    ''')

    sink = io.BytesIO()

    def written_lines():
        return sink.getvalue().decode().split('\n')

    try:
        # Copy with default options.
        status = await self.con.copy_from_table('copytab', output=sink)
        self.assertEqual(status, 'COPY 6')

        text_rows = ['a{0}\tb{0}\t{0}'.format(n) for n in range(1, 6)]
        self.assertEqual(
            written_lines(),
            text_rows + ['*\t\\N\t\\N', ''],
        )

        # Copy again with the formatting parameters overridden.
        await self.con.execute('SET search_path=none')

        sink.seek(0)
        sink.truncate()

        status = await self.con.copy_from_table(
            'copytab',
            output=sink,
            columns=('a', 'b~'),
            schema_name='public',
            format='csv',
            delimiter='|',
            null='n-u-l-l',
            header=True,
            quote='*',
            escape='!',
            force_quote=('a',))

        csv_rows = ['*a{0}*|b{0}'.format(n) for n in range(1, 6)]
        self.assertEqual(
            written_lines(),
            ['a|b~'] + csv_rows + ['*!**|n-u-l-l', ''],
        )

        await self.con.execute('SET search_path=public')
    finally:
        await self.con.execute('DROP TABLE public.copytab')
async def test_copy_from_query_to_path(self):
    """Copy the result of a parameterized query into a file given by path.

    The output file lives in a temporary directory that is removed when
    the test finishes, so nothing is left on disk.
    """
    # A closed NamedTemporaryFile has already been deleted, so reusing its
    # name lets copy_from_query() create a fresh file that is not reliably
    # cleaned up afterwards.  Use a private directory instead.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'copy_from_query.out')

        await self.con.copy_from_query('''
            SELECT
                i, i * 10
            FROM generate_series(1, 5) AS i
            WHERE i = $1
        ''', 3, output=path)

        with open(path, 'rb') as fr:
            output = fr.read().decode().split('\n')

        self.assertEqual(
            output,
            [
                '3\t30',
                ''
            ]
        )
async def test_copy_from_query_to_bad_output(self):
    """Passing ``output=1`` must raise the "output is expected to be" error."""
    query = '''
        SELECT
            i, i * 10
        FROM generate_series(1, 5) AS i
        WHERE i = $1
    '''
    rejects_output = self.assertRaisesRegex(
        TypeError, 'output is expected to be')

    with rejects_output:
        await self.con.copy_from_query(query, 3, output=1)
async def test_copy_from_query_timeout_1(self):
    """A sink that sleeps on every chunk makes the copy hit its timeout.

    After the timeout the connection must still answer a simple query.
    """
    async def slow_sink(data):
        await asyncio.sleep(0.05)

    query = '''
        SELECT
            repeat('a', 500)
        FROM generate_series(1, 5000) AS i
    '''
    copy_task = self.loop.create_task(
        self.con.copy_from_query(query, output=slow_sink, timeout=0.10))

    with self.assertRaises(asyncio.TimeoutError):
        await copy_task

    still_alive = await self.con.fetchval('SELECT 1')
    self.assertEqual(still_alive, 1)
async def test_copy_to_table_from_bytes_like(self):
    """Copy two rows into ``copytab`` from a ``memoryview`` source."""
    await self.con.execute('''
        CREATE TABLE copytab(a text, b text);
    ''')

    try:
        row = b'\t'.join((b'a1' * 500, b'b1' * 500)) + b'\n'
        payload = memoryview(row * 2)

        status = await self.con.copy_to_table('copytab', source=payload)
        self.assertEqual(status, 'COPY 2')
    finally:
        await self.con.execute('DROP TABLE copytab')
async def test_copy_to_table_fail_in_source_2(self):
    """A source that fails after yielding one row aborts the copy.

    The error raised by the source must propagate, and the connection
    must still answer a query afterwards.
    """
    await self.con.execute('''
        CREATE TABLE copytab(a text, b text);
    ''')

    class _Source:
        def __init__(self):
            self.rowcount = 0

        def __aiter__(self):
            return self

        async def __anext__(self):
            # Only the very first request produces a row.
            if self.rowcount != 0:
                raise RuntimeError('failure in source')
            self.rowcount += 1
            return b'a\tb\n'

    try:
        with self.assertRaisesRegex(RuntimeError, 'failure in source'):
            await self.con.copy_to_table('copytab', source=_Source())

        # The protocol must be usable again after the failed copy.
        still_alive = await self.con.fetchval('SELECT 1')
        self.assertEqual(still_alive, 1)
    finally:
        await self.con.execute('DROP TABLE copytab')
async def test_copy_to_table_from_file_path(self):
    """Copy rows into ``copytab`` from a file given by its path.

    The temporary input file is always closed and removed, and the table
    is always dropped, regardless of which step fails.
    """
    await self.con.execute('''
        CREATE TABLE copytab(a text, "b~" text, i int);
    ''')

    try:
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            # The ``with`` block closes the handle even if the write fails.
            with f:
                f.write(
                    '\n'.join([
                        'a1\tb1\t1',
                        'a2\tb2\t2',
                        'a3\tb3\t3',
                        'a4\tb4\t4',
                        'a5\tb5\t5',
                        '*\t\\N\t\\N',
                        ''
                    ]).encode('utf-8')
                )

            res = await self.con.copy_to_table('copytab', source=f.name)
            self.assertEqual(res, 'COPY 6')

            output = await self.con.fetch("""
                SELECT * FROM copytab ORDER BY a
            """)
            self.assertEqual(
                output,
                [
                    ('*', None, None),
                    ('a1', 'b1', 1),
                    ('a2', 'b2', 2),
                    ('a3', 'b3', 3),
                    ('a4', 'b4', 4),
                    ('a5', 'b5', 5),
                ]
            )
        finally:
            # Remove the file independently of the table cleanup below.
            os.unlink(f.name)
    finally:
        await self.con.execute('DROP TABLE public.copytab')
async def test_copy_records_to_table_no_binary_codec(self):
    """Copying records fails when ``uuid`` only has a text-format codec.

    A pass-through text codec is installed for ``uuid``; the copy must
    then raise the "no binary format encoder" error.  The codec is reset
    and the table dropped afterwards.
    """
    await self.con.execute('''
        CREATE TABLE copytab(a uuid);
    ''')

    def _passthrough(value):
        return value

    try:
        await self.con.set_type_codec(
            'uuid',
            encoder=_passthrough,
            decoder=_passthrough,
            schema='pg_catalog',
            format='text',
        )

        rows = [('2975ab9a-f79c-4ab4-9be5-7bc134d952f0',)]
        expects_no_encoder = self.assertRaisesRegex(
            asyncpg.InternalClientError, 'no binary format encoder')

        with expects_no_encoder:
            await self.con.copy_records_to_table('copytab', records=rows)
    finally:
        await self.con.reset_type_codec('uuid', schema='pg_catalog')
        await self.con.execute('DROP TABLE copytab')
async def test_cursor_iterable_05(self):
    """Iterating a cursor with ``prefetch`` of -1 or 0 must be rejected."""
    stmt = await self.con.prepare('SELECT generate_series(0, 20)')

    for prefetch in (-1, 0):
        with self.subTest(prefetch=prefetch), \
                self.assertRaisesRegex(asyncpg.InterfaceError,
                                       'must be greater than zero'):
            async for _ in stmt.cursor(prefetch=prefetch):  # NOQA
                pass
async def test_exceptions_unpacking(self):
    """Check the fields carried by ``UndefinedTableError``.

    Selecting from a missing table must raise the error with its
    SQLSTATE, position, query text and severity populated.
    """
    query = 'SELECT * FROM _nonexistent_'

    with self.assertRaises(asyncpg.UndefinedTableError) as caught:
        await self.con.execute(query)

    err = caught.exception
    self.assertEqual(err.sqlstate, '42P01')
    self.assertEqual(err.position, '15')
    self.assertEqual(err.query, 'SELECT * FROM _nonexistent_')
    self.assertIsNotNone(err.severity)
async def test_execute_script_2(self):
    """``execute()`` returns the command status for DDL and for an INSERT."""
    create_status = await self.con.execute('''
        CREATE TABLE mytab (a int);
    ''')
    self.assertEqual(create_status, 'CREATE TABLE')

    try:
        insert_sql = '''
            INSERT INTO mytab (a) VALUES ($1), ($2)
        '''
        insert_status = await self.con.execute(insert_sql, 10, 20)
        self.assertEqual(insert_status, 'INSERT 0 2')
    finally:
        await self.con.execute('DROP TABLE mytab')
def tearDown(self):
    """Drop the ``exmany`` table, then run the inherited teardown."""
    try:
        self.loop.run_until_complete(self.con.execute('DROP TABLE exmany'))
    finally:
        # Always run the base-class teardown, even if the DROP fails,
        # so that a broken statement does not skip it.
        super().tearDown()
async def test_executemany_server_failure(self):
    """A unique violation in the batch leaves ``exmany`` empty."""
    insert_sql = '''
        INSERT INTO exmany VALUES($1, $2)
    '''
    # The third row repeats key 2, which violates the primary key.
    rows = [('a', 1), ('b', 2), ('c', 2), ('d', 4)]

    with self.assertRaises(exceptions.UniqueViolationError):
        await self.con.executemany(insert_sql, rows)

    stored = await self.con.fetch('SELECT * FROM exmany')
    self.assertEqual(stored, [])
async def test_executemany_client_failure_after_writes(self):
    """An error raised by the argument generator leaves ``exmany`` empty."""
    def failing_rows():
        # The final value of ``y`` is 0, so ``y / y`` raises
        # ZeroDivisionError while the arguments are being produced.
        for y in range(10, -1, -1):
            yield ('a' * 32768, y + y / y)

    with self.assertRaises(ZeroDivisionError):
        await self.con.executemany('''
            INSERT INTO exmany VALUES($1, $2)
        ''', failing_rows())

    stored = await self.con.fetch('SELECT b FROM exmany')
    self.assertEqual(stored, [])
async def test_executemany_prepare(self):
    """``executemany()`` on a prepared statement inserts every row.

    It returns ``None``, and calling it with an empty sequence leaves the
    table unchanged.
    """
    stmt = await self.con.prepare('''
        INSERT INTO exmany VALUES($1, $2)
    ''')
    select_all = '''
        SELECT * FROM exmany
    '''
    rows = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    outcome = await stmt.executemany(rows)
    self.assertIsNone(outcome)

    stored = await self.con.fetch(select_all)
    self.assertEqual(stored, [('a', 1), ('b', 2), ('c', 3), ('d', 4)])

    # Empty set
    await stmt.executemany(())

    stored = await self.con.fetch(select_all)
    self.assertEqual(stored, [('a', 1), ('b', 2), ('c', 3), ('d', 4)])
asyncpg import _testbase as tb
from asyncpg import connection as apg_con


# Limit handed to assertRunUnder() in test_introspection_on_large_db
# (presumably seconds -- confirm against _testbase).
MAX_RUNTIME = 0.25


class SlowIntrospectionConnection(apg_con.Connection):
    """Connection class to test introspection races."""

    # Number of _introspect_types() calls observed; starts at the
    # class-level 0 and is incremented through the instance.
    introspect_count = 0

    async def _introspect_types(self, *args, **kwargs):
        # Record the call, stall for 0.4s, then delegate to the parent
        # implementation with the arguments unchanged.
        self.introspect_count += 1
        await asyncio.sleep(0.4)
        return await super()._introspect_types(*args, **kwargs)


class TestIntrospection(tb.ConnectedTestCase):
    @classmethod
    def setUpClass(cls):
        # Open a class-wide admin connection and create the scratch
        # database named in the database='asyncpg_intro_test' options below.
        super().setUpClass()
        cls.adminconn = cls.loop.run_until_complete(cls.connect())
        cls.loop.run_until_complete(
            cls.adminconn.execute('CREATE DATABASE asyncpg_intro_test'))

    @classmethod
    def tearDownClass(cls):
        # Mirror of setUpClass: drop the scratch database, then close and
        # forget the admin connection before the parent teardown runs.
        cls.loop.run_until_complete(
            cls.adminconn.execute('DROP DATABASE asyncpg_intro_test'))
        cls.loop.run_until_complete(cls.adminconn.close())
        cls.adminconn = None
        super().tearDownClass()

    @classmethod
    def get_server_settings(cls):
        # Start from the parent's settings and remove the 'jit' entry
        # if it is present.
        settings = super().get_server_settings()
        settings.pop('jit', None)
        return settings

    def setUp(self):
        # Every test starts with the custom "oid" codec installed on
        # the test connection.
        super().setUp()
        self.loop.run_until_complete(self._add_custom_codec(self.con))

    async def _add_custom_codec(self, conn):
        # mess with the codec - builtin introspection shouldn't be affected
        await conn.set_type_codec(
            "oid",
            schema="pg_catalog",
            encoder=lambda value: None,
            decoder=lambda value: None,
            format="text",
        )

    @tb.with_connection_options(database='asyncpg_intro_test')
    async def test_introspection_on_large_db(self):
        # Build a 50-column base table with 1000 inheriting children,
        # then require an array-typed query to finish within MAX_RUNTIME.
        await self.con.execute(
            'CREATE TABLE base ({})'.format(
                ','.join('c{:02} varchar'.format(n) for n in range(50))
            )
        )
        for n in range(1000):
            await self.con.execute(
                'CREATE TABLE child_{:04} () inherits (base)'.format(n)
            )

        with self.assertRunUnder(MAX_RUNTIME):
            await self.con.fetchval('SELECT $1::int[]', [1, 2])

    @tb.with_connection_options(statement_cache_size=0)
    async def test_introspection_no_stmt_cache_01(self):
        # With the statement cache disabled, apg_con._uid is expected to
        # be unchanged once the test has run (asserted at the end).
        old_uid = apg_con._uid

        self.assertEqual(self.con._stmt_cache.get_max_size(), 0)
        await self.con.fetchval('SELECT $1::int[]', [1, 2])

        await self.con.execute('''
CREATE EXTENSION IF NOT EXISTS hstore ''')

        try:
            await self.con.set_builtin_type_codec(
                'hstore', codec_name='pg_contrib.hstore')
        finally:
            # Always remove the extension again, even if codec setup fails.
            await self.con.execute(''' DROP EXTENSION hstore ''')

        self.assertEqual(apg_con._uid, old_uid)

    @tb.with_connection_options(max_cacheable_statement_size=1)
    async def test_introspection_no_stmt_cache_02(self):
        # max_cacheable_statement_size will disable caching both for
        # the user query and for the introspection query.
        old_uid = apg_con._uid

        await self.con.fetchval('SELECT $1::int[]', [1, 2])

        await self.con.execute(''' CREATE EXTENSION IF NOT EXISTS hstore ''')

        try:
            await self.con.set_builtin_type_codec(
                'hstore', codec_name='pg_contrib.hstore')
        finally:
            # Always remove the extension again, even if codec setup fails.
            await self.con.execute(''' DROP EXTENSION hstore ''')

        # Nothing was cached, so the module-level uid must not have moved.
        self.assertEqual(apg_con._uid, old_uid)

    @tb.with_connection_options(max_cacheable_statement_size=10000)
    async def test_introspection_no_stmt_cache_03(self):
        # max_cacheable_statement_size will disable caching for
        # the user query but not for the introspection query.
        old_uid = apg_con._uid

        # The literal pads the query text past the 10000 limit set above.
        await self.con.fetchval(
            "SELECT $1::int[], '{foo}'".format(foo='a' * 10000), [1, 2])

        self.assertGreater(apg_con._uid, old_uid)

    async def test_introspection_sticks_for_ps(self):
        # Test that the introspected codec pipeline for a prepared
        # statement is not affected by a subsequent codec cache bust.

        ps = await self.con._prepare('SELECT $1::json[]', use_cache=True)

        try:
            # Setting a custom codec blows the codec cache for derived types.
            await self.con.set_type_codec(
                'json', encoder=lambda v: v, decoder=json.loads,
                schema='pg_catalog', format='text'
            )

            # The originally prepared statement should still be OK and
            # use the previously selected codec.
            self.assertEqual(await ps.fetchval(['{"foo": 1}']), ['{"foo": 1}'])

            # The new query uses the custom codec.
v = await self.con.fetchval('SELECT $1::json[]', ['{"foo": 1}'])
            self.assertEqual(v, [{'foo': 1}])

        finally:
            # Restore the default json codec whatever the outcome above.
            await self.con.reset_type_codec(
                'json', schema='pg_catalog')

    async def test_introspection_retries_after_cache_bust(self):
        # Test that codec cache bust racing with the introspection
        # query would cause introspection to retry.
        slow_intro_conn = await self.connect(
            connection_class=SlowIntrospectionConnection)

        # Give the extra connection the same custom "oid" codec that
        # setUp() installs on self.con.
        await self._add_custom_codec(slow_intro_conn)
        try:
            await self.con.execute(''' CREATE DOMAIN intro_1_t AS int; CREATE DOMAIN intro_2_t AS int; ''')

            await slow_intro_conn.fetchval(''' SELECT $1::intro_1_t ''', 10)
            # slow_intro_conn cache is now populated with intro_1_t

            async def wait_and_drop():
                # Fire after 0.1s, i.e. while the 0.4s-delayed
                # introspection of the query below is still sleeping.
                await asyncio.sleep(0.1)
                await slow_intro_conn.reload_schema_state()

            # Now, in parallel, run another query that
            # references both intro_1_t and intro_2_t.
            await asyncio.gather(
                slow_intro_conn.fetchval(''' SELECT $1::intro_1_t, $2::intro_2_t ''', 10, 20),
                wait_and_drop()
            )

            # Initial query + two tries for the second query.
self.assertEqual(slow_intro_conn.introspect_count, 3)

        finally:
            # Drop both domains and close the extra connection.
            await self.con.execute(''' DROP DOMAIN intro_1_t; DROP DOMAIN intro_2_t; ''')
            await slow_intro_conn.close()

    @tb.with_connection_options(database='asyncpg_intro_test')
    async def test_introspection_loads_basetypes_of_domains(self):
        # Test that basetypes of domains are loaded to the
        # client encode/decode cache
        await self.con.execute(''' DROP TABLE IF EXISTS test; DROP DOMAIN IF EXISTS num_array; CREATE DOMAIN num_array numeric[]; CREATE TABLE test ( num num_array ); ''')

        try:
            # if domain basetypes are not loaded, this insert will fail
            await self.con.execute(
                'INSERT INTO test (num) VALUES ($1)', ([1, 2],))
        finally:
            # Remove the table and domain again so reruns start clean.
            await self.con.execute(''' DROP TABLE IF EXISTS test; DROP DOMAIN IF EXISTS num_array; ''')
================================================
FILE: tests/test_listeners.py
================================================
# Copyright (C) 2016-present the asyncpg authors and contributors
#
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import os
import platform
import unittest

from asyncpg import _testbase as tb
from asyncpg import exceptions


class TestListeners(tb.ClusterTestCase):

    async def test_listen_01(self):
        # Register two plain callbacks and one coroutine callback on the
        # same channel; each one forwards its arguments to its own queue.
        async with self.create_pool(database='postgres') as pool:
            async with pool.acquire() as con:

                q1 = asyncio.Queue()
                q2 = asyncio.Queue()
                q3 = asyncio.Queue()

                def listener1(*args):
                    q1.put_nowait(args)

                def listener2(*args):
                    q2.put_nowait(args)

                async def async_listener3(*args):
                    q3.put_nowait(args)

                await con.add_listener('test', listener1)
                await con.add_listener('test', listener2)
                await con.add_listener('test', async_listener3)

                await con.execute("NOTIFY test, 'aaaa'")

                # Every listener receives the same 4-tuple of arguments:
                # (con, con.get_server_pid(), 'test', 'aaaa').
                self.assertEqual(
                    await q1.get(),
                    (con, con.get_server_pid(), 'test', 'aaaa'))
                self.assertEqual(
                    await q2.get(),
                    (con, con.get_server_pid(), 'test', 'aaaa'))
                self.assertEqual(
                    await q3.get(),
                    (con, con.get_server_pid(), 'test', 'aaaa'))
await con.remove_listener('test', listener2) await con.remove_listener('test', async_listener3) await con.execute("NOTIFY test, 'aaaa'") self.assertEqual( await q1.get(), (con, con.get_server_pid(), 'test', 'aaaa')) with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for(q2.get(), timeout=0.05) await con.reset() await con.remove_listener('test', listener1) await con.execute("NOTIFY test, 'aaaa'") with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for(q1.get(), timeout=0.05) with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for(q2.get(), timeout=0.05) async def test_listen_02(self): async with self.create_pool(database='postgres') as pool: async with pool.acquire() as con1, pool.acquire() as con2: q1 = asyncio.Queue() def listener1(*args): q1.put_nowait(args) await con1.add_listener('ipc', listener1) await con2.execute("NOTIFY ipc, 'hello'") self.assertEqual( await q1.get(), (con1, con2.get_server_pid(), 'ipc', 'hello')) await con1.remove_listener('ipc', listener1) async def test_listen_notletters(self): async with self.create_pool(database='postgres') as pool: async with pool.acquire() as con1, pool.acquire() as con2: q1 = asyncio.Queue() def listener1(*args): q1.put_nowait(args) await con1.add_listener('12+"34', listener1) await con2.execute("""NOTIFY "12+""34", 'hello'""") self.assertEqual( await q1.get(), (con1, con2.get_server_pid(), '12+"34', 'hello')) await con1.remove_listener('12+"34', listener1) async def test_dangling_listener_warns(self): async with self.create_pool(database='postgres') as pool: with self.assertWarnsRegex( exceptions.InterfaceWarning, '.*Connection.*is being released to the pool but ' 'has 1 active notification listener'): async with pool.acquire() as con: def listener1(*args): pass await con.add_listener('ipc', listener1) class TestLogListeners(tb.ConnectedTestCase): @tb.with_connection_options(server_settings={ 'client_min_messages': 'notice' }) async def test_log_listener_01(self): q1 = 
asyncio.Queue() q2 = asyncio.Queue() def notice_callb(con, message): # Message fields depend on PG version, hide some values. dct = message.as_dict() del dct['server_source_line'] q1.put_nowait((con, type(message), dct)) async def async_notice_callb(con, message): # Message fields depend on PG version, hide some values. dct = message.as_dict() del dct['server_source_line'] q2.put_nowait((con, type(message), dct)) async def raise_notice(): await self.con.execute( """DO $$ BEGIN RAISE NOTICE 'catch me!'; END; $$ LANGUAGE plpgsql""" ) async def raise_warning(): await self.con.execute( """DO $$ BEGIN RAISE WARNING 'catch me!'; END; $$ LANGUAGE plpgsql""" ) con = self.con con.add_log_listener(notice_callb) con.add_log_listener(async_notice_callb) expected_msg = { 'context': 'PL/pgSQL function inline_code_block line 2 at RAISE', 'message': 'catch me!', 'server_source_function': 'exec_stmt_raise', } expected_msg_notice = { **expected_msg, 'severity': 'NOTICE', 'severity_en': 'NOTICE', 'sqlstate': '00000', } expected_msg_warn = { **expected_msg, 'severity': 'WARNING', 'severity_en': 'WARNING', 'sqlstate': '01000', } if con.get_server_version() < (9, 6): del expected_msg_notice['context'] del expected_msg_notice['severity_en'] del expected_msg_warn['context'] del expected_msg_warn['severity_en'] await raise_notice() await raise_warning() msg = await q1.get() msg[2].pop('server_source_filename', None) self.assertEqual( msg, (con, exceptions.PostgresLogMessage, expected_msg_notice)) msg = await q1.get() msg[2].pop('server_source_filename', None) self.assertEqual( msg, (con, exceptions.PostgresWarning, expected_msg_warn)) msg = await q2.get() msg[2].pop('server_source_filename', None) self.assertEqual( msg, (con, exceptions.PostgresLogMessage, expected_msg_notice)) msg = await q2.get() msg[2].pop('server_source_filename', None) self.assertEqual( msg, (con, exceptions.PostgresWarning, expected_msg_warn)) con.remove_log_listener(notice_callb) 
con.remove_log_listener(async_notice_callb) await raise_notice() self.assertTrue(q1.empty()) con.add_log_listener(notice_callb) await raise_notice() await q1.get() self.assertTrue(q1.empty()) await con.reset() await raise_notice() self.assertTrue(q1.empty()) @tb.with_connection_options(server_settings={ 'client_min_messages': 'notice' }) async def test_log_listener_02(self): q1 = asyncio.Queue() cur_id = None def notice_callb(con, message): q1.put_nowait((con, cur_id, message.message)) con = self.con await con.execute( "CREATE FUNCTION _test(i INT) RETURNS int LANGUAGE plpgsql AS $$" " BEGIN" " RAISE NOTICE '1_%', i;" " PERFORM pg_sleep(0.1);" " RAISE NOTICE '2_%', i;" " RETURN i;" " END" "$$" ) try: con.add_log_listener(notice_callb) for cur_id in range(10): await con.execute("SELECT _test($1)", cur_id) for cur_id in range(10): self.assertEqual( q1.get_nowait(), (con, cur_id, '1_%s' % cur_id)) self.assertEqual( q1.get_nowait(), (con, cur_id, '2_%s' % cur_id)) con.remove_log_listener(notice_callb) self.assertTrue(q1.empty()) finally: await con.execute('DROP FUNCTION _test(i INT)') @tb.with_connection_options(server_settings={ 'client_min_messages': 'notice' }) async def test_log_listener_03(self): q1 = asyncio.Queue() async def raise_message(level, code): await self.con.execute(""" DO $$ BEGIN RAISE {} 'catch me!' USING ERRCODE = '{}'; END; $$ LANGUAGE plpgsql; """.format(level, code)) def notice_callb(con, message): # Message fields depend on PG version, hide some values. 
q1.put_nowait(message)

        self.con.add_log_listener(notice_callb)

        # A warning raised with code '99999' arrives as a plain
        # PostgresWarning carrying that sqlstate.
        await raise_message('WARNING', '99999')
        msg = await q1.get()
        self.assertIsInstance(msg, exceptions.PostgresWarning)
        self.assertEqual(msg.sqlstate, '99999')

        # Code '01004' arrives as StringDataRightTruncation.
        await raise_message('WARNING', '01004')
        msg = await q1.get()
        self.assertIsInstance(msg, exceptions.StringDataRightTruncation)
        self.assertEqual(msg.sqlstate, '01004')

        # With an empty level the statement raises in the caller instead,
        # and nothing is delivered to the log listener.
        with self.assertRaises(exceptions.InvalidCharacterValueForCastError):
            await raise_message('', '22018')
        self.assertTrue(q1.empty())

    async def test_dangling_log_listener_warns(self):
        # Releasing a connection that still has a log listener attached
        # must emit an InterfaceWarning matching the regex below.
        async with self.create_pool(database='postgres') as pool:
            with self.assertWarnsRegex(
                    exceptions.InterfaceWarning,
                    '.*Connection.*is being released to the pool but '
                    'has 1 active log listener'):
                async with pool.acquire() as con:
                    def listener1(*args):
                        pass

                    con.add_log_listener(listener1)


@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
@unittest.skipIf(
    platform.system() == 'Windows',
    'not compatible with ProactorEventLoop which is default in Python 3.8+')
class TestConnectionTerminationListener(tb.ProxiedClusterTestCase):

    async def test_connection_termination_callback_called_on_remote(self):
        # Both a plain and a coroutine termination listener must have run
        # after the proxy closes all connections.

        called = False
        async_called = False

        def close_cb(con):
            nonlocal called
            called = True

        async def async_close_cb(con):
            nonlocal async_called
            async_called = True

        con = await self.connect()
        con.add_termination_listener(close_cb)
        con.add_termination_listener(async_close_cb)
        self.proxy.close_all_connections()
        try:
            await con.fetchval('SELECT 1')
        except Exception:
            # Any failure of the query is ignored on purpose; only the
            # listener flags asserted below matter.
            pass
        self.assertTrue(called)
        self.assertTrue(async_called)

    async def test_connection_termination_callback_called_on_local(self):
        # A termination listener must also run after a local close().

        called = False

        def close_cb(con):
            nonlocal called
            called = True

        con = await self.connect()
        con.add_termination_listener(close_cb)
        await con.close()
        # Yield to the event loop once before checking the flag.
        await asyncio.sleep(0)
        self.assertTrue(called)
================================================
FILE: tests/test_logging.py
================================================ import asyncio from asyncpg import _testbase as tb from asyncpg import exceptions class LogCollector: def __init__(self): self.records = [] def __call__(self, record): self.records.append(record) class TestQueryLogging(tb.ConnectedTestCase): async def test_logging_context(self): queries = asyncio.Queue() def query_saver(record): queries.put_nowait(record) log = LogCollector() with self.con.query_logger(query_saver): self.assertEqual(len(self.con._query_loggers), 1) await self.con.execute("SELECT 1") with self.con.query_logger(log): self.assertEqual(len(self.con._query_loggers), 2) await self.con.execute("SELECT 2") r1 = await queries.get() r2 = await queries.get() self.assertEqual(r1.query, "SELECT 1") self.assertEqual(r2.query, "SELECT 2") self.assertEqual(len(log.records), 1) self.assertEqual(log.records[0].query, "SELECT 2") self.assertEqual(len(self.con._query_loggers), 0) async def test_error_logging(self): log = LogCollector() with self.con.query_logger(log): with self.assertRaises(exceptions.UndefinedColumnError): await self.con.execute("SELECT x") await asyncio.sleep(0) # wait for logging self.assertEqual(len(log.records), 1) self.assertEqual( type(log.records[0].exception), exceptions.UndefinedColumnError ) ================================================ FILE: tests/test_pool.py ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 import asyncio import inspect import os import pathlib import platform import random import textwrap import time import unittest import asyncpg from asyncpg import _testbase as tb from asyncpg import connection as pg_connection from asyncpg import pool as pg_pool from asyncpg import cluster as pg_cluster _system = platform.uname().system POOL_NOMINAL_TIMEOUT = 0.5 class 
SlowResetConnection(pg_connection.Connection):
    """Connection class to simulate races with Connection.reset()."""

    async def reset(self, *, timeout=None):
        # Stall for 0.2s before delegating, so a caller is held inside
        # reset() for at least that long.
        await asyncio.sleep(0.2)
        return await super().reset(timeout=timeout)


class SlowCancelConnection(pg_connection.Connection):
    """Connection class to simulate races with Connection._cancel()."""

    async def _cancel(self, waiter):
        # Stall for 0.2s before delegating to the real _cancel().
        await asyncio.sleep(0.2)
        return await super()._cancel(waiter)


class TestPool(tb.ConnectedTestCase):

    async def test_pool_01(self):
        # Run n workers against a 5..10 connection pool using explicit
        # acquire()/release() calls and an explicit pool.close().
        for n in {1, 5, 10, 20, 100}:
            with self.subTest(tasksnum=n):
                pool = await self.create_pool(database='postgres',
                                              min_size=5, max_size=10)

                async def worker():
                    con = await pool.acquire()
                    self.assertEqual(await con.fetchval('SELECT 1'), 1)
                    await pool.release(con)

                tasks = [worker() for _ in range(n)]
                await asyncio.gather(*tasks)
                await pool.close()

    async def test_pool_02(self):
        # Same worker pattern, but the pool is a fixed-size (5) async
        # context manager and acquire() carries a 5s timeout.
        for n in {1, 3, 5, 10, 20, 100}:
            with self.subTest(tasksnum=n):
                async with self.create_pool(database='postgres',
                                            min_size=5, max_size=5) as pool:

                    async def worker():
                        con = await pool.acquire(timeout=5)
                        self.assertEqual(await con.fetchval('SELECT 1'), 1)
                        await pool.release(con)

                    tasks = [worker() for _ in range(n)]
                    await asyncio.gather(*tasks)

    async def test_pool_03(self):
        # With the only connection checked out, a second acquire() must
        # time out.
        pool = await self.create_pool(database='postgres',
                                      min_size=1, max_size=1)

        con = await pool.acquire(timeout=1)
        with self.assertRaises(asyncio.TimeoutError):
            await pool.acquire(timeout=0.03)

        pool.terminate()
        del con

    async def test_pool_04(self):
        pool = await self.create_pool(database='postgres',
                                      min_size=1, max_size=1)

        con = await pool.acquire(timeout=POOL_NOMINAL_TIMEOUT)

        # Manual termination of pool connections releases the
        # pool item immediately.
con.terminate() self.assertIsNone(pool._holders[0]._con) self.assertIsNone(pool._holders[0]._in_use) con = await pool.acquire(timeout=POOL_NOMINAL_TIMEOUT) self.assertEqual(await con.fetchval('SELECT 1'), 1) await con.close() self.assertIsNone(pool._holders[0]._con) self.assertIsNone(pool._holders[0]._in_use) # Calling release should not hurt. await pool.release(con) pool.terminate() async def test_pool_05(self): for n in {1, 3, 5, 10, 20, 100}: with self.subTest(tasksnum=n): pool = await self.create_pool(database='postgres', min_size=5, max_size=10) async def worker(): async with pool.acquire() as con: self.assertEqual(await con.fetchval('SELECT 1'), 1) tasks = [worker() for _ in range(n)] await asyncio.gather(*tasks) await pool.close() async def test_pool_06(self): fut = asyncio.Future() async def setup(con): fut.set_result(con) async with self.create_pool(database='postgres', min_size=5, max_size=5, setup=setup) as pool: async with pool.acquire() as con: pass self.assertIs(con, await fut) async def test_pool_07(self): cons = set() connect_called = 0 init_called = 0 setup_called = 0 reset_called = 0 async def connect(*args, **kwargs): nonlocal connect_called connect_called += 1 return await pg_connection.connect(*args, **kwargs) async def setup(con): nonlocal setup_called if con._con not in cons: # `con` is `PoolConnectionProxy`. raise RuntimeError('init was not called before setup') setup_called += 1 async def init(con): nonlocal init_called if con in cons: raise RuntimeError('init was called more than once') cons.add(con) init_called += 1 async def reset(con): nonlocal reset_called reset_called += 1 async def user(pool): async with pool.acquire() as con: if con._con not in cons: # `con` is `PoolConnectionProxy`. 
raise RuntimeError('init was not called') async with self.create_pool(database='postgres', min_size=2, max_size=5, connect=connect, init=init, setup=setup, reset=reset) as pool: users = asyncio.gather(*[user(pool) for _ in range(10)]) await users self.assertEqual(len(cons), 5) self.assertEqual(connect_called, 5) self.assertEqual(init_called, 5) self.assertEqual(setup_called, 10) self.assertEqual(reset_called, 10) async def bad_connect(*args, **kwargs): return 1 with self.assertRaisesRegex( asyncpg.InterfaceError, "expected pool connect callback to return an instance of " "'asyncpg\\.connection\\.Connection', got 'int'" ): await self.create_pool(database='postgres', connect=bad_connect) async def test_pool_08(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) con = await pool.acquire(timeout=POOL_NOMINAL_TIMEOUT) with self.assertRaisesRegex(asyncpg.InterfaceError, 'is not a member'): await pool.release(con._con) async def test_pool_09(self): pool1 = await self.create_pool(database='postgres', min_size=1, max_size=1) pool2 = await self.create_pool(database='postgres', min_size=1, max_size=1) try: con = await pool1.acquire(timeout=POOL_NOMINAL_TIMEOUT) with self.assertRaisesRegex(asyncpg.InterfaceError, 'is not a member'): await pool2.release(con) finally: await pool1.release(con) await pool1.close() await pool2.close() async def test_pool_10(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) con = await pool.acquire() await pool.release(con) await pool.release(con) await pool.close() async def test_pool_11(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) async with pool.acquire() as con: self.assertIn(repr(con._con), repr(con)) # Test __repr__. 
ps = await con.prepare('SELECT 1') txn = con.transaction() async with con.transaction(): cur = await con.cursor('SELECT 1') ps_cur = await ps.cursor() self.assertIn('[released]', repr(con)) with self.assertRaisesRegex( asyncpg.InterfaceError, r'cannot call Connection\.execute.*released back to the pool'): con.execute('select 1') for meth in ('fetchval', 'fetchrow', 'fetch', 'explain', 'get_query', 'get_statusmsg', 'get_parameters', 'get_attributes'): with self.assertRaisesRegex( asyncpg.InterfaceError, r'cannot call PreparedStatement\.{meth}.*released ' r'back to the pool'.format(meth=meth)): getattr(ps, meth)() for c in (cur, ps_cur): for meth in ('fetch', 'fetchrow'): with self.assertRaisesRegex( asyncpg.InterfaceError, r'cannot call Cursor\.{meth}.*released ' r'back to the pool'.format(meth=meth)): getattr(c, meth)() with self.assertRaisesRegex( asyncpg.InterfaceError, r'cannot call Cursor\.forward.*released ' r'back to the pool'): c.forward(1) for meth in ('start', 'commit', 'rollback'): with self.assertRaisesRegex( asyncpg.InterfaceError, r'cannot call Transaction\.{meth}.*released ' r'back to the pool'.format(meth=meth)): getattr(txn, meth)() await pool.close() async def test_pool_12(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) async with pool.acquire() as con: self.assertTrue(isinstance(con, pg_connection.Connection)) self.assertFalse(isinstance(con, list)) await pool.close() async def test_pool_13(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) async with pool.acquire() as con: self.assertIn('Execute an SQL command', con.execute.__doc__) self.assertEqual(con.execute.__name__, 'execute') self.assertIn( str(inspect.signature(con.execute))[1:], str(inspect.signature(pg_connection.Connection.execute))) await pool.close() def test_pool_init_run_until_complete(self): pool_init = self.create_pool(database='postgres') pool = self.loop.run_until_complete(pool_init) self.assertIsInstance(pool, 
asyncpg.pool.Pool) async def test_pool_exception_in_setup_and_init(self): class Error(Exception): pass async def setup(con): nonlocal setup_calls, last_con last_con = con setup_calls += 1 if setup_calls > 1: cons.append(con) else: cons.append('error') raise Error with self.subTest(method='setup'): setup_calls = 0 last_con = None cons = [] async with self.create_pool(database='postgres', min_size=1, max_size=1, setup=setup) as pool: with self.assertRaises(Error): await pool.acquire() self.assertTrue(last_con.is_closed()) async with pool.acquire() as con: self.assertEqual(cons, ['error', con]) with self.subTest(method='init'): setup_calls = 0 last_con = None cons = [] async with self.create_pool(database='postgres', min_size=0, max_size=1, init=setup) as pool: with self.assertRaises(Error): await pool.acquire() self.assertTrue(last_con.is_closed()) async with pool.acquire() as con: self.assertEqual(await con.fetchval('select 1::int'), 1) self.assertEqual(cons, ['error', con._con]) async def test_pool_auth(self): if not self.cluster.is_managed(): self.skipTest('unmanaged cluster') self.cluster.reset_hba() if _system != 'Windows': self.cluster.add_hba_entry( type='local', database='postgres', user='pooluser', auth_method='md5') self.cluster.add_hba_entry( type='host', address='127.0.0.1/32', database='postgres', user='pooluser', auth_method='md5') self.cluster.add_hba_entry( type='host', address='::1/128', database='postgres', user='pooluser', auth_method='md5') self.cluster.reload() try: await self.con.execute(''' CREATE ROLE pooluser WITH LOGIN PASSWORD 'poolpassword' ''') pool = await self.create_pool(database='postgres', user='pooluser', password='poolpassword', min_size=5, max_size=10) async def worker(): con = await pool.acquire() self.assertEqual(await con.fetchval('SELECT 1'), 1) await pool.release(con) tasks = [worker() for _ in range(5)] await asyncio.gather(*tasks) await pool.close() finally: await self.con.execute('DROP ROLE pooluser') # Reset cluster's 
# pg_hba.conf since we've meddled with it
            self.cluster.trust_local_connections()

            self.cluster.reload()

    async def test_pool_handles_task_cancel_in_acquire_with_timeout(self):
        # See https://github.com/MagicStack/asyncpg/issues/547
        pool = await self.create_pool(database='postgres',
                                      min_size=1, max_size=1)

        async def worker():
            async with pool.acquire(timeout=100):
                pass

        # Schedule task
        task = self.loop.create_task(worker())
        # Yield to task, but cancel almost immediately
        await asyncio.sleep(0.00000000001)
        # Cancel the worker.
        task.cancel()
        # Wait to make sure the cleanup has completed.
        await asyncio.sleep(0.4)
        # Check that the connection has been returned to the pool.
        self.assertEqual(pool._queue.qsize(), 1)

    async def test_pool_handles_task_cancel_in_release(self):
        # Use SlowResetConnection (its reset() sleeps 0.2s first) to
        # simulate the Task.cancel() and __aexit__ race.
        pool = await self.create_pool(database='postgres',
                                      min_size=1, max_size=1,
                                      connection_class=SlowResetConnection)

        async def worker():
            async with pool.acquire():
                pass

        task = self.loop.create_task(worker())
        # Let the worker() run.
        await asyncio.sleep(0.1)
        # Cancel the worker.
        task.cancel()
        # Wait to make sure the cleanup has completed.
        await asyncio.sleep(0.4)
        # Check that the connection has been returned to the pool.
        self.assertEqual(pool._queue.qsize(), 1)

    async def test_pool_handles_query_cancel_in_release(self):
        # Use SlowCancelConnection (its _cancel() sleeps 0.2s first) to
        # simulate the Task.cancel() and __aexit__ race while a query
        # is still running on the connection.
        pool = await self.create_pool(database='postgres',
                                      min_size=1, max_size=1,
                                      connection_class=SlowCancelConnection)

        async def worker():
            async with pool.acquire() as con:
                await con.execute('SELECT pg_sleep(10)')

        task = self.loop.create_task(worker())
        # Let the worker() run.
        await asyncio.sleep(0.1)
        # Cancel the worker.
        task.cancel()
        # Wait to make sure the cleanup has completed.
        await asyncio.sleep(0.5)
        # Check that the connection has been returned to the pool.
self.assertEqual(pool._queue.qsize(), 1) async def test_pool_no_acquire_deadlock(self): async with self.create_pool(database='postgres', min_size=1, max_size=1, max_queries=1) as pool: async def sleep_and_release(): async with pool.acquire() as con: await con.execute('SELECT pg_sleep(1)') asyncio.ensure_future(sleep_and_release()) await asyncio.sleep(0.5) async with pool.acquire() as con: await con.fetchval('SELECT 1') async def test_pool_config_persistence(self): N = 100 cons = set() class MyConnection(asyncpg.Connection): async def foo(self): return 42 async def fetchval(self, query): res = await super().fetchval(query) return res + 1 async def test(pool): async with pool.acquire() as con: self.assertEqual(await con.fetchval('SELECT 1'), 2) self.assertEqual(await con.foo(), 42) self.assertTrue(isinstance(con, MyConnection)) self.assertEqual(con._con._config.statement_cache_size, 3) cons.add(con) async with self.create_pool( database='postgres', min_size=10, max_size=10, max_queries=1, connection_class=MyConnection, statement_cache_size=3) as pool: await asyncio.gather(*[test(pool) for _ in range(N)]) self.assertEqual(len(cons), N) async def test_pool_release_in_xact(self): """Test that Connection.reset() closes any open transaction.""" async with self.create_pool(database='postgres', min_size=1, max_size=1) as pool: async def get_xact_id(con): return await con.fetchval('select txid_current()') with self.assertLoopErrorHandlerCalled('an active transaction'): async with pool.acquire() as con: real_con = con._con # unwrap PoolConnectionProxy id1 = await get_xact_id(con) tr = con.transaction() self.assertIsNone(con._con._top_xact) await tr.start() self.assertIs(real_con._top_xact, tr) id2 = await get_xact_id(con) self.assertNotEqual(id1, id2) self.assertIsNone(real_con._top_xact) async with pool.acquire() as con: self.assertIs(con._con, real_con) self.assertIsNone(con._con._top_xact) id3 = await get_xact_id(con) self.assertNotEqual(id2, id3) async def 
test_pool_connection_methods(self): async def test_fetch(pool): i = random.randint(0, 20) await asyncio.sleep(random.random() / 100) r = await pool.fetch('SELECT {}::int'.format(i)) self.assertEqual(r, [(i,)]) return 1 async def test_fetchrow(pool): i = random.randint(0, 20) await asyncio.sleep(random.random() / 100) r = await pool.fetchrow('SELECT {}::int'.format(i)) self.assertEqual(r, (i,)) return 1 async def test_fetchval(pool): i = random.randint(0, 20) await asyncio.sleep(random.random() / 100) r = await pool.fetchval('SELECT {}::int'.format(i)) self.assertEqual(r, i) return 1 async def test_execute(pool): await asyncio.sleep(random.random() / 100) r = await pool.execute('SELECT generate_series(0, 10)') self.assertEqual(r, 'SELECT {}'.format(11)) return 1 async def test_execute_with_arg(pool): i = random.randint(0, 20) await asyncio.sleep(random.random() / 100) r = await pool.execute('SELECT generate_series(0, $1)', i) self.assertEqual(r, 'SELECT {}'.format(i + 1)) return 1 async def run(N, meth): async with self.create_pool(database='postgres', min_size=5, max_size=10) as pool: coros = [meth(pool) for _ in range(N)] res = await asyncio.gather(*coros) self.assertEqual(res, [1] * N) methods = [test_fetch, test_fetchrow, test_fetchval, test_execute, test_execute_with_arg] with tb.silence_asyncio_long_exec_warning(): for method in methods: with self.subTest(method=method.__name__): await run(200, method) async def test_pool_connection_execute_many(self): async def worker(pool): await asyncio.sleep(random.random() / 100) await pool.executemany(''' INSERT INTO exmany VALUES($1, $2) ''', [ ('a', 1), ('b', 2), ('c', 3), ('d', 4) ]) return 1 N = 200 async with self.create_pool(database='postgres', min_size=5, max_size=10) as pool: await pool.execute('CREATE TABLE exmany (a text, b int)') try: coros = [worker(pool) for _ in range(N)] res = await asyncio.gather(*coros) self.assertEqual(res, [1] * N) n_rows = await pool.fetchval('SELECT count(*) FROM exmany') 
self.assertEqual(n_rows, N * 4) finally: await pool.execute('DROP TABLE exmany') async def test_pool_max_inactive_time_01(self): async with self.create_pool( database='postgres', min_size=1, max_size=1, max_inactive_connection_lifetime=0.1) as pool: # Test that it's OK if a query takes longer time to execute # than `max_inactive_connection_lifetime`. con = pool._holders[0]._con for _ in range(3): await pool.execute('SELECT pg_sleep(0.5)') self.assertIs(pool._holders[0]._con, con) self.assertEqual( await pool.execute('SELECT 1::int'), 'SELECT 1') self.assertIs(pool._holders[0]._con, con) async def test_pool_max_inactive_time_02(self): async with self.create_pool( database='postgres', min_size=1, max_size=1, max_inactive_connection_lifetime=0.5) as pool: # Test that we have a new connection after pool not # being used longer than `max_inactive_connection_lifetime`. con = pool._holders[0]._con self.assertEqual( await pool.execute('SELECT 1::int'), 'SELECT 1') self.assertIs(pool._holders[0]._con, con) await asyncio.sleep(1) self.assertIs(pool._holders[0]._con, None) self.assertEqual( await pool.execute('SELECT 1::int'), 'SELECT 1') self.assertIsNot(pool._holders[0]._con, con) async def test_pool_max_inactive_time_03(self): async with self.create_pool( database='postgres', min_size=1, max_size=1, max_inactive_connection_lifetime=1) as pool: # Test that we start counting inactive time *after* # the connection is being released back to the pool. con = pool._holders[0]._con await pool.execute('SELECT pg_sleep(0.5)') await asyncio.sleep(0.6) self.assertIs(pool._holders[0]._con, con) self.assertEqual( await pool.execute('SELECT 1::int'), 'SELECT 1') self.assertIs(pool._holders[0]._con, con) async def test_pool_max_inactive_time_04(self): # Chaos test for max_inactive_connection_lifetime. 
# NOTE(review): this span opens inside test_pool_max_inactive_time_04, whose
# `async def` header sits on the preceding line.
        DURATION = 2.0
        START = time.monotonic()
        N = 0  # number of completed worker() invocations

        async def worker(pool):
            nonlocal N
            await asyncio.sleep(random.random() / 10 + 0.1)
            async with pool.acquire() as con:
                # Roughly half of the iterations also run a short pg_sleep.
                if random.random() > 0.5:
                    await con.execute('SELECT pg_sleep({:.2f})'.format(
                        random.random() / 10))
                self.assertEqual(
                    await con.fetchval('SELECT 42::int'),
                    42)

            # Keep recursing until DURATION seconds have elapsed since START.
            if time.monotonic() - START < DURATION:
                await worker(pool)

            N += 1

        async with self.create_pool(
                database='postgres', min_size=10, max_size=30,
                max_inactive_connection_lifetime=0.1) as pool:

            workers = [worker(pool) for _ in range(50)]
            await asyncio.gather(*workers)

        # Each of the 50 top-level workers increments N at least once.
        self.assertGreaterEqual(N, 50)

    async def test_pool_max_inactive_time_05(self):
        # Test that idle never-acquired connections abide by
        # the max inactive lifetime.
        async with self.create_pool(
                database='postgres', min_size=2, max_size=2,
                max_inactive_connection_lifetime=0.2) as pool:

            self.assertIsNotNone(pool._holders[0]._con)
            self.assertIsNotNone(pool._holders[1]._con)

            await pool.execute('SELECT pg_sleep(0.3)')
            await asyncio.sleep(0.3)

            self.assertIs(pool._holders[0]._con, None)
            # The connection in the second holder was never used,
            # but should be closed nonetheless.
self.assertIs(pool._holders[1]._con, None) async def test_pool_handles_inactive_connection_errors(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) con = await pool.acquire(timeout=POOL_NOMINAL_TIMEOUT) true_con = con._con await pool.release(con) # we simulate network error by terminating the connection true_con.terminate() # now pool should reopen terminated connection async with pool.acquire(timeout=POOL_NOMINAL_TIMEOUT) as con: self.assertEqual(await con.fetchval('SELECT 1'), 1) await con.close() await pool.close() async def test_pool_size_and_capacity(self): async with self.create_pool( database='postgres', min_size=2, max_size=3, ) as pool: self.assertEqual(pool.get_min_size(), 2) self.assertEqual(pool.get_max_size(), 3) self.assertEqual(pool.get_size(), 2) self.assertEqual(pool.get_idle_size(), 2) async with pool.acquire(): self.assertEqual(pool.get_idle_size(), 1) async with pool.acquire(): self.assertEqual(pool.get_idle_size(), 0) async with pool.acquire(): self.assertEqual(pool.get_size(), 3) self.assertEqual(pool.get_idle_size(), 0) async def test_pool_closing(self): async with self.create_pool() as pool: self.assertFalse(pool.is_closing()) await pool.close() self.assertTrue(pool.is_closing()) async with self.create_pool() as pool: self.assertFalse(pool.is_closing()) pool.terminate() self.assertTrue(pool.is_closing()) async def test_pool_handles_transaction_exit_in_asyncgen_1(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) locals_ = {} exec(textwrap.dedent('''\ async def iterate(con): async with con.transaction(): for record in await con.fetch("SELECT 1"): yield record '''), globals(), locals_) iterate = locals_['iterate'] class MyException(Exception): pass with self.assertRaises(MyException): async with pool.acquire() as con: async for _ in iterate(con): # noqa raise MyException() async def test_pool_handles_transaction_exit_in_asyncgen_2(self): pool = await 
self.create_pool(database='postgres', min_size=1, max_size=1) locals_ = {} exec(textwrap.dedent('''\ async def iterate(con): async with con.transaction(): for record in await con.fetch("SELECT 1"): yield record '''), globals(), locals_) iterate = locals_['iterate'] class MyException(Exception): pass with self.assertRaises(MyException): async with pool.acquire() as con: iterator = iterate(con) async for _ in iterator: # noqa raise MyException() del iterator async def test_pool_handles_asyncgen_finalization(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) locals_ = {} exec(textwrap.dedent('''\ async def iterate(con): for record in await con.fetch("SELECT 1"): yield record '''), globals(), locals_) iterate = locals_['iterate'] class MyException(Exception): pass with self.assertRaises(MyException): async with pool.acquire() as con: async with con.transaction(): async for _ in iterate(con): # noqa raise MyException() async def test_pool_close_waits_for_release(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) flag = self.loop.create_future() conn_released = False async def worker(): nonlocal conn_released async with pool.acquire() as connection: async with connection.transaction(): flag.set_result(True) await asyncio.sleep(0.1) conn_released = True self.loop.create_task(worker()) await flag await pool.close() self.assertTrue(conn_released) async def test_pool_close_timeout(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) flag = self.loop.create_future() async def worker(): async with pool.acquire(): flag.set_result(True) await asyncio.sleep(0.5) task = self.loop.create_task(worker()) with self.assertRaises(asyncio.TimeoutError): await flag await asyncio.wait_for(pool.close(), timeout=0.1) await task async def test_pool_expire_connections(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) con = await pool.acquire() try: await 
pool.expire_connections() finally: await pool.release(con) self.assertIsNone(pool._holders[0]._con) await pool.close() async def test_pool_set_connection_args(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1) # Test that connection is expired on release. con = await pool.acquire() connspec = self.get_connection_spec() try: connspec['server_settings']['application_name'] = \ 'set_conn_args_test' except KeyError: connspec['server_settings'] = { 'application_name': 'set_conn_args_test' } pool.set_connect_args(**connspec) await pool.expire_connections() await pool.release(con) con = await pool.acquire() self.assertEqual(con.get_settings().application_name, 'set_conn_args_test') await pool.release(con) # Test that connection is expired before acquire. connspec = self.get_connection_spec() try: connspec['server_settings']['application_name'] = \ 'set_conn_args_test' except KeyError: connspec['server_settings'] = { 'application_name': 'set_conn_args_test_2' } pool.set_connect_args(**connspec) await pool.expire_connections() con = await pool.acquire() self.assertEqual(con.get_settings().application_name, 'set_conn_args_test_2') await pool.release(con) await pool.close() async def test_pool_init_race(self): pool = self.create_pool(database='postgres', min_size=1, max_size=1) t1 = asyncio.ensure_future(pool) t2 = asyncio.ensure_future(pool) await t1 with self.assertRaisesRegex( asyncpg.InterfaceError, r'pool is being initialized in another task'): await t2 await pool.close() async def test_pool_init_and_use_race(self): pool = self.create_pool(database='postgres', min_size=1, max_size=1) pool_task = asyncio.ensure_future(pool) await asyncio.sleep(0) with self.assertRaisesRegex( asyncpg.InterfaceError, r'being initialized, but not yet ready'): await pool.fetchval('SELECT 1') await pool_task await pool.close() async def test_pool_remote_close(self): pool = await self.create_pool(min_size=1, max_size=1) backend_pid_fut = self.loop.create_future() 
# NOTE(review): this span opens inside test_pool_remote_close and closes
# inside TestPrepare.test_prepare_01 of the next file in the dump.
        async def worker():
            async with pool.acquire() as conn:
                pool_backend_pid = await conn.fetchval(
                    'SELECT pg_backend_pid()')
                backend_pid_fut.set_result(pool_backend_pid)
                # Keep holding the connection while the test body terminates
                # its backend.
                await asyncio.sleep(0.2)

        task = self.loop.create_task(worker())
        try:
            conn = await self.connect()
            backend_pid = await backend_pid_fut
            await conn.execute('SELECT pg_terminate_backend($1)', backend_pid)
        finally:
            await conn.close()

        await task

        # Check that connection_lost has released the pool holder.
        conn = await pool.acquire(timeout=0.1)
        await pool.release(conn)


@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
class TestPoolReconnectWithTargetSessionAttrs(tb.ClusterTestCase):
    # Uses its own temporary cluster so that it can be stopped and restarted
    # with a standby.signal file.

    @classmethod
    def setup_cluster(cls):
        cls.cluster = cls.new_cluster(pg_cluster.TempCluster)
        cls.start_cluster(cls.cluster)

    async def simulate_cluster_recovery_mode(self):
        # Restart the cluster on the same port with a standby.signal file
        # present in its data directory.
        port = self.cluster.get_connection_spec()['port']
        # stop()/start() run in the default executor, off the event loop.
        await self.loop.run_in_executor(
            None,
            lambda: self.cluster.stop()
        )

        # Simulate recovery mode
        (pathlib.Path(self.cluster._data_dir) / 'standby.signal').touch()

        await self.loop.run_in_executor(
            None,
            lambda: self.cluster.start(
                port=port,
                server_settings=self.get_server_settings(),
            )
        )

    async def test_full_reconnect_on_node_change_role(self):
        if self.cluster.get_pg_version() < (12, 0):
            self.skipTest("PostgreSQL < 12 cannot support standby.signal")
            return

        pool = await self.create_pool(
            min_size=1,
            max_size=1,
            target_session_attrs='primary'
        )

        # Force a new connection to be created
        await pool.fetchval('SELECT 1')

        await self.simulate_cluster_recovery_mode()

        # current pool connection info cache is expired,
        # but we don't know it yet
        with self.assertRaises(asyncpg.TargetServerAttributeNotMatched) as cm:
            await pool.execute('SELECT 1')

        # NOTE(review): the second literal below is empty in this dump; it
        # looks like an angle-bracketed value was stripped during
        # extraction -- confirm against the upstream file.
        self.assertEqual(
            cm.exception.args[0],
            "None of the hosts match the target attribute requirement "
            ""
        )

        # force reconnect
        with self.assertRaises(asyncpg.TargetServerAttributeNotMatched) as cm:
            await pool.execute('SELECT 1')

        self.assertEqual(
            cm.exception.args[0],
            "None of the hosts match the target attribute requirement "
            ""
        )


@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
class TestHotStandby(tb.HotStandbyTestCase):
    def create_pool(self, **kwargs):
        # Build a pool from the standby cluster's connection spec, with
        # caller-supplied keyword arguments taking precedence.
        conn_spec = self.standby_cluster.get_connection_spec()
        conn_spec.update(kwargs)
        return pg_pool.create_pool(loop=self.loop, **conn_spec)

    async def test_standby_pool_01(self):
        # Run n concurrent acquire/query/release workers for several n.
        for n in {1, 3, 5, 10, 20, 100}:
            with self.subTest(tasksnum=n):
                pool = await self.create_pool(
                    database='postgres', user='postgres',
                    min_size=5, max_size=10)

                async def worker():
                    con = await pool.acquire()
                    self.assertEqual(await con.fetchval('SELECT 1'), 1)
                    await pool.release(con)

                tasks = [worker() for _ in range(n)]
                await asyncio.gather(*tasks)
                await pool.close()

    async def test_standby_cursors(self):
        # Open a cursor inside a transaction on a direct standby connection.
        con = await self.standby_cluster.connect(
            database='postgres', user='postgres', loop=self.loop)

        try:
            async with con.transaction():
                cursor = await con.cursor('SELECT 1')
                self.assertEqual(await cursor.fetchrow(), (1,))
        finally:
            await con.close()


================================================
FILE: tests/test_prepare.py
================================================
# Copyright (C) 2016-present the asyncpg authors and contributors
#
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import asyncpg
import gc
import unittest

from asyncpg import _testbase as tb
from asyncpg import exceptions


class TestPrepare(tb.ConnectedTestCase):

    async def test_prepare_01(self):
        # prepare() alone must not bump the protocol's query counter; each
        # fetch on the statement bumps it by one.
        self.assertEqual(self.con._protocol.queries_count, 0)
        st = await self.con.prepare('SELECT 1 = $1 AS test')
        self.assertEqual(self.con._protocol.queries_count, 0)
        self.assertEqual(st.get_query(), 'SELECT 1 = $1 AS test')

        rec = await st.fetchrow(1)
        self.assertEqual(self.con._protocol.queries_count, 1)
        self.assertTrue(rec['test'])
        self.assertEqual(len(rec), 1)

        self.assertEqual(False, await st.fetchval(10))
# NOTE(review): this span opens at the last statement of test_prepare_01,
# which starts on the preceding line.
        self.assertEqual(self.con._protocol.queries_count, 2)

    async def test_prepare_02(self):
        # Preparing a statement that references an unknown column fails.
        with self.assertRaisesRegex(Exception, 'column "a" does not exist'):
            await self.con.prepare('SELECT a')

    async def test_prepare_03(self):
        # Each case is (SQL type, (SQL default literal, expected Python
        # value when the argument is None), values to round-trip).
        cases = [
            ('text', ("'NULL'", 'NULL'), [
                'aaa',
                None
            ]),

            ('decimal', ('0', 0), [
                123,
                123.5,
                None
            ])
        ]

        for type, (none_name, none_val), vals in cases:
            st = await self.con.prepare(''' SELECT CASE WHEN $1::{type} IS NULL THEN {default} ELSE $1::{type} END'''.format(
                type=type, default=none_name))

            for val in vals:
                with self.subTest(type=type, value=val):
                    res = await st.fetchval(val)
                    if val is None:
                        self.assertEqual(res, none_val)
                    else:
                        self.assertEqual(res, val)

    async def test_prepare_04(self):
        s = await self.con.prepare('SELECT $1::smallint')
        self.assertEqual(await s.fetchval(10), 10)

        s = await self.con.prepare('SELECT $1::smallint * 2')
        self.assertEqual(await s.fetchval(10), 20)

        s = await self.con.prepare('SELECT generate_series(5,10)')
        self.assertEqual(await s.fetchval(), 5)
        # Since the "execute" message was sent with a limit=1,
        # we will receive a PortalSuspended message, instead of
        # CommandComplete.  Which means there will be no status
        # message set.
        self.assertIsNone(s.get_statusmsg())
        # Repeat the same test for 'fetchrow()'.
# NOTE(review): this span opens inside test_prepare_04, which starts on the
# preceding line.
        self.assertEqual(await s.fetchrow(), (5,))
        self.assertIsNone(s.get_statusmsg())

    async def test_prepare_05_unknownoid(self):
        # A bare string literal (no explicit cast) comes back as 'test'.
        s = await self.con.prepare("SELECT 'test'")
        self.assertEqual(await s.fetchval(), 'test')

    async def test_prepare_06_interrupted_close(self):
        # Closing the connection while a statement is executing surfaces
        # QueryCanceledError from the pending fetch.
        stmt = await self.con.prepare('''SELECT pg_sleep(10)''')
        fut = self.loop.create_task(stmt.fetch())

        await asyncio.sleep(0.2)

        self.assertFalse(self.con.is_closed())
        await self.con.close()
        self.assertTrue(self.con.is_closed())

        with self.assertRaises(asyncpg.QueryCanceledError):
            await fut

        # Test that it's OK to call close again
        await self.con.close()

    async def test_prepare_07_interrupted_terminate(self):
        # Terminating the connection while a statement is executing
        # surfaces ConnectionDoesNotExistError from the pending fetch.
        stmt = await self.con.prepare('''SELECT pg_sleep(10)''')
        fut = self.loop.create_task(stmt.fetchval())

        await asyncio.sleep(0.2)

        self.assertFalse(self.con.is_closed())
        self.con.terminate()
        self.assertTrue(self.con.is_closed())

        with self.assertRaisesRegex(asyncpg.ConnectionDoesNotExistError,
                                    'closed in the middle'):
            await fut

        # Test that it's OK to call terminate again
        self.con.terminate()

    async def test_prepare_08_big_result(self):
        # Fetch 10001 rows and check both the count and the ordering.
        stmt = await self.con.prepare('select generate_series(0,10000)')
        result = await stmt.fetch()

        self.assertEqual(len(result), 10001)
        self.assertEqual(
            [r[0] for r in result],
            list(range(10001)))

    async def test_prepare_09_raise_error(self):
        # Stress test ReadBuffer.read_cstr()
        msg = '0' * 1024 * 100
        query = """ DO language plpgsql $$ BEGIN RAISE EXCEPTION '{}'; END $$;""".format(msg)

        stmt = await self.con.prepare(query)
        with self.assertRaisesRegex(asyncpg.RaiseError, msg):
            with tb.silence_asyncio_long_exec_warning():
                await stmt.fetchval()

    async def test_prepare_10_stmt_lru(self):
        cache = self.con._stmt_cache

        query = 'select {}'
        cache_max = cache.get_max_size()
        # Prepare more than twice as many statements as the cache can hold.
        iter_max = cache_max * 2 + 11

        # First, we have no cached statements.
# NOTE(review): this span opens inside test_prepare_10_stmt_lru, which
# starts on the preceding line.
        self.assertEqual(len(cache), 0)

        stmts = []
        for i in range(iter_max):
            s = await self.con._prepare(query.format(i), use_cache=True)
            self.assertEqual(await s.fetchval(), i)
            stmts.append(s)

        # At this point our cache should be full.
        self.assertEqual(len(cache), cache_max)
        self.assertTrue(all(not s.closed for s in cache.iter_statements()))

        # Since there are references to the statements (`stmts` list),
        # no statements are scheduled to be closed.
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Removing refs to statements and preparing a new statement
        # will cause connection to cleanup any stale statements.
        stmts.clear()
        gc.collect()

        # Now we have a bunch of statements that have no refs to them
        # scheduled to be closed.
        self.assertEqual(len(self.con._stmts_to_close), iter_max - cache_max)
        self.assertTrue(all(s.closed for s in self.con._stmts_to_close))
        self.assertTrue(all(not s.closed for s in cache.iter_statements()))

        zero = await self.con.prepare(query.format(0))
        # Hence, all stale statements should be closed now.
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # The number of cached statements will stay the same though.
        self.assertEqual(len(cache), cache_max)
        self.assertTrue(all(not s.closed for s in cache.iter_statements()))

        # After closing all statements will be closed.
        await self.con.close()
        self.assertEqual(len(self.con._stmts_to_close), 0)
        self.assertEqual(len(cache), 0)

        # An attempt to perform an operation on a closed statement
        # will trigger an error.
        with self.assertRaisesRegex(asyncpg.InterfaceError, 'is closed'):
            await zero.fetchval()

    async def test_prepare_11_stmt_gc(self):
        # Test that prepared statements should stay in the cache after
        # they are GCed.

        cache = self.con._stmt_cache

        # First, we have no cached statements.
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # The prepared statement that we'll create will be GCed
        # right away.  However, its state should still be in
        # the statements LRU cache.
# NOTE(review): this span opens inside test_prepare_11_stmt_gc, which starts
# on the preceding line.
        await self.con._prepare('select 1', use_cache=True)
        gc.collect()

        self.assertEqual(len(cache), 1)
        self.assertEqual(len(self.con._stmts_to_close), 0)

    async def test_prepare_12_stmt_gc(self):
        # Test that prepared statements are closed when there is no space
        # for them in the LRU cache and there are no references to them.

        cache = self.con._stmt_cache
        cache_max = cache.get_max_size()

        # First, we have no cached statements.
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        stmt = await self.con._prepare('select 100000000', use_cache=True)
        self.assertEqual(len(cache), 1)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Fill the cache with cache_max other statements, pushing `stmt`
        # out of it.
        for i in range(cache_max):
            await self.con._prepare('select {}'.format(i), use_cache=True)

        self.assertEqual(len(cache), cache_max)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Dropping the last reference schedules the evicted statement for
        # closing.
        del stmt
        gc.collect()

        self.assertEqual(len(cache), cache_max)
        self.assertEqual(len(self.con._stmts_to_close), 1)

    async def test_prepare_13_connect(self):
        # Exercise the connection-level fetchval/fetchrow/fetch shortcuts
        # with query arguments.
        v = await self.con.fetchval(
            'SELECT $1::smallint AS foo', 10, column='foo')
        self.assertEqual(v, 10)

        r = await self.con.fetchrow('SELECT $1::smallint * 2 AS test', 10)
        self.assertEqual(r['test'], 20)

        rows = await self.con.fetch('SELECT generate_series(0,$1::int)', 3)
        self.assertEqual([r[0] for r in rows], [0, 1, 2, 3])

    async def test_prepare_14_explain(self):
        # Test simple EXPLAIN.
        stmt = await self.con.prepare('SELECT typname FROM pg_type')
        plan = await stmt.explain()
        self.assertEqual(plan[0]['Plan']['Relation Name'], 'pg_type')

        # Test "EXPLAIN ANALYZE".
        stmt = await self.con.prepare(
            'SELECT typname, typlen FROM pg_type WHERE typlen > $1')
        plan = await stmt.explain(2, analyze=True)
        self.assertEqual(plan[0]['Plan']['Relation Name'], 'pg_type')
        self.assertIn('Actual Total Time', plan[0]['Plan'])

        # Test that 'EXPLAIN ANALYZE' is executed in a transaction
        # that gets rolled back.
# NOTE(review): this span opens inside test_prepare_14_explain and closes
# inside test_prepare_23_no_stmt_cache_seq.
        tr = self.con.transaction()
        await tr.start()
        try:
            await self.con.execute('CREATE TABLE mytab (a int)')
            stmt = await self.con.prepare(
                'INSERT INTO mytab (a) VALUES (1), (2)')
            plan = await stmt.explain(analyze=True)
            self.assertEqual(plan[0]['Plan']['Operation'], 'Insert')

            # Check that no data was inserted
            res = await self.con.fetch('SELECT * FROM mytab')
            self.assertEqual(res, [])
        finally:
            await tr.rollback()

    async def test_prepare_15_stmt_gc_cache_disabled(self):
        # Test that even if the statements cache is off, we're still
        # cleaning up GCed statements.

        cache = self.con._stmt_cache

        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        # Disable cache
        cache.set_max_size(0)

        stmt = await self.con._prepare('select 100000000', use_cache=True)
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        del stmt
        gc.collect()

        # After GC, _stmts_to_close should contain stmt's state
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 1)

        # Next "prepare" call will trigger a cleanup
        stmt = await self.con._prepare('select 1', use_cache=True)
        self.assertEqual(len(cache), 0)
        self.assertEqual(len(self.con._stmts_to_close), 0)

        del stmt

    async def test_prepare_16_command_result(self):
        async def status(query):
            # Prepare and run `query`, returning the statement's status
            # message.
            stmt = await self.con.prepare(query)
            await stmt.fetch()
            return stmt.get_statusmsg()

        try:
            self.assertEqual(
                await status('CREATE TABLE mytab (a int)'),
                'CREATE TABLE')

            self.assertEqual(
                await status('INSERT INTO mytab (a) VALUES (1), (2)'),
                'INSERT 0 2')

            self.assertEqual(
                await status('SELECT a FROM mytab'),
                'SELECT 2')

            self.assertEqual(
                await status('UPDATE mytab SET a = 3 WHERE a = 1'),
                'UPDATE 1')
        finally:
            self.assertEqual(
                await status('DROP TABLE mytab'),
                'DROP TABLE')

    async def test_prepare_17_stmt_closed_lru(self):
        st = await self.con.prepare('SELECT 1')
        # Mark the underlying state closed behind the statement's back.
        st._state.mark_closed()
        with self.assertRaisesRegex(asyncpg.InterfaceError, 'is closed'):
            await st.fetch()

        # Preparing the same query again must yield a usable statement.
        st = await self.con.prepare('SELECT 1')
        self.assertEqual(await st.fetchval(), 1)

    async def test_prepare_18_empty_result(self):
        # test EmptyQueryResponse protocol message
        st = await self.con.prepare('')
        self.assertEqual(await st.fetch(), [])
        self.assertIsNone(await st.fetchval())
        self.assertIsNone(await st.fetchrow())

        self.assertEqual(await self.con.fetch(''), [])
        self.assertIsNone(await self.con.fetchval(''))
        self.assertIsNone(await self.con.fetchrow(''))

    async def test_prepare_19_concurrent_calls(self):
        st = self.loop.create_task(self.con.fetchval(
            'SELECT ROW(pg_sleep(0.1), 1)'))

        # Wait for some time to make sure the first query is fully
        # prepared (!) and is now awaiting the results (!!).
        await asyncio.sleep(0.01)

        with self.assertRaisesRegex(asyncpg.InterfaceError,
                                    'another operation'):
            await self.con.execute('SELECT 2')

        self.assertEqual(await st, (None, 1))

    async def test_prepare_20_concurrent_calls(self):
        # Same as test 19, but for each of fetch/fetchval/fetchrow; `val`
        # is the result shape that particular method returns.
        expected = ((None, 1),)

        for methname, val in [('fetch', [expected]),
                              ('fetchval', expected[0]),
                              ('fetchrow', expected)]:

            with self.subTest(meth=methname):

                meth = getattr(self.con, methname)

                vf = self.loop.create_task(
                    meth('SELECT ROW(pg_sleep(0.1), 1)'))

                await asyncio.sleep(0.01)

                with self.assertRaisesRegex(asyncpg.InterfaceError,
                                            'another operation'):
                    await meth('SELECT 2')

                self.assertEqual(await vf, val)

    async def test_prepare_21_errors(self):
        stmt = await self.con.prepare('SELECT 10 / $1::int')

        with self.assertRaises(asyncpg.DivisionByZeroError):
            await stmt.fetchval(0)

        # The statement stays usable after a failed execution.
        self.assertEqual(await stmt.fetchval(5), 2)

    async def test_prepare_22_empty(self):
        # Support for empty target list was added in PostgreSQL 9.4
        if self.server_version < (9, 4):
            raise unittest.SkipTest(
                'PostgreSQL servers < 9.4 do not support empty target list.')

        result = await self.con.fetchrow('SELECT')
        self.assertEqual(result, ())
        # NOTE(review): the expected repr below is empty in this dump; it
        # looks like an angle-bracketed literal was stripped during
        # extraction -- confirm against the upstream file.
        self.assertEqual(repr(result), '')

    async def test_prepare_statement_invalid(self):
        await self.con.execute('CREATE TABLE tab1(a int, b int)')

        try:
            await self.con.execute('INSERT INTO tab1 VALUES (1, 2)')

            stmt = await self.con.prepare('SELECT * FROM tab1')

            # Changing a column type after the statement was prepared.
            await self.con.execute(
                'ALTER TABLE tab1 ALTER COLUMN b SET DATA TYPE text')

            with self.assertRaisesRegex(asyncpg.InvalidCachedStatementError,
                                        'cached statement plan is invalid'):
                await stmt.fetchrow()

        finally:
            await self.con.execute('DROP TABLE tab1')

    @tb.with_connection_options(statement_cache_size=0)
    async def test_prepare_23_no_stmt_cache_seq(self):
        self.assertEqual(self.con._stmt_cache.get_max_size(), 0)

        async def check_simple():
            # Run a simple query a few times.
            self.assertEqual(await self.con.fetchval('SELECT 1'), 1)
            self.assertEqual(await self.con.fetchval('SELECT 2'), 2)
            self.assertEqual(await self.con.fetchval('SELECT 1'), 1)

        await check_simple()

        # Run a query that timeouts.
        with self.assertRaises(asyncio.TimeoutError):
            await self.con.fetchrow('select pg_sleep(10)', timeout=0.02)

        # Check that we can run new queries after a timeout.
        await check_simple()

        # Try a cursor/timeout combination. Cursors should always use
        # named prepared statements.
        async with self.con.transaction():
            with self.assertRaises(asyncio.TimeoutError):
                async for _ in self.con.cursor(   # NOQA
                        'select pg_sleep(10)', timeout=0.1):
                    pass

        # Check that we can run queries after a failed cursor
        # operation.
# NOTE(review): this span opens at the last statement of
# test_prepare_23_no_stmt_cache_seq, which starts on the preceding line.
        await check_simple()

    @tb.with_connection_options(max_cached_statement_lifetime=142)
    async def test_prepare_24_max_lifetime(self):
        cache = self.con._stmt_cache

        self.assertEqual(cache.get_max_lifetime(), 142)
        cache.set_max_lifetime(1)

        s = await self.con._prepare('SELECT 1', use_cache=True)
        state = s._state

        # Repeated prepares within the lifetime reuse the cached state.
        s = await self.con._prepare('SELECT 1', use_cache=True)
        self.assertIs(s._state, state)

        s = await self.con._prepare('SELECT 1', use_cache=True)
        self.assertIs(s._state, state)

        await asyncio.sleep(1)

        # After the 1s lifetime has passed, a new state is created.
        s = await self.con._prepare('SELECT 1', use_cache=True)
        self.assertIsNot(s._state, state)

    @tb.with_connection_options(max_cached_statement_lifetime=0.5)
    async def test_prepare_25_max_lifetime_reset(self):
        cache = self.con._stmt_cache

        s = await self.con._prepare('SELECT 1', use_cache=True)
        state = s._state

        # Disable max_lifetime
        cache.set_max_lifetime(0)

        await asyncio.sleep(1)

        # The statement should still be cached (as we disabled the timeout).
        s = await self.con._prepare('SELECT 1', use_cache=True)
        self.assertIs(s._state, state)

    @tb.with_connection_options(max_cached_statement_lifetime=0.5)
    async def test_prepare_26_max_lifetime_max_size(self):
        cache = self.con._stmt_cache

        s = await self.con._prepare('SELECT 1', use_cache=True)
        state = s._state

        # Set the cache size to zero (this is a size change, not a
        # max_lifetime change).
        cache.set_max_size(0)

        s = await self.con._prepare('SELECT 1', use_cache=True)
        self.assertIsNot(s._state, state)

        # Check that nothing crashes after the initial timeout
        await asyncio.sleep(1)

    @tb.with_connection_options(max_cacheable_statement_size=50)
    async def test_prepare_27_max_cacheable_statement_size(self):
        cache = self.con._stmt_cache

        await self.con._prepare('SELECT 1', use_cache=True)
        self.assertEqual(len(cache), 1)

        # Test that long and explicitly created prepared statements
        # are not cached.
        await self.con._prepare("SELECT \'" + "a" * 50 + "\'", use_cache=True)
        self.assertEqual(len(cache), 1)

        # Test that implicitly created long prepared statements
        # are not cached.
# NOTE(review): this span opens inside
# test_prepare_27_max_cacheable_statement_size and closes on the `else:` of
# the second try block of test_prepare_31_pgbouncer_note.
        await self.con.fetchval("SELECT \'" + "a" * 50 + "\'")
        self.assertEqual(len(cache), 1)

        # Test that short prepared statements can still be cached.
        await self.con._prepare('SELECT 2', use_cache=True)
        self.assertEqual(len(cache), 2)

    async def test_prepare_28_max_args(self):
        # One argument more than the 32767 limit named in the expected
        # error message.
        N = 32768
        args = ','.join('${}'.format(i) for i in range(1, N + 1))
        query = 'SELECT ARRAY[{}]'.format(args)

        with self.assertRaisesRegex(
                exceptions.InterfaceError,
                'the number of query arguments cannot exceed 32767'):
            await self.con.fetchval(query, *range(1, N + 1))

    async def test_prepare_29_duplicates(self):
        # In addition to test_record.py, let's have a full functional
        # test for records with duplicate keys.
        r = await self.con.fetchrow('SELECT 1 as a, 2 as b, 3 as a')
        self.assertEqual(list(r.items()), [('a', 1), ('b', 2), ('a', 3)])
        self.assertEqual(list(r.keys()), ['a', 'b', 'a'])
        self.assertEqual(list(r.values()), [1, 2, 3])
        # Lookup by a duplicated name resolves to the last such column.
        self.assertEqual(r['a'], 3)
        self.assertEqual(r['b'], 2)
        self.assertEqual(r[0], 1)
        self.assertEqual(r[1], 2)
        self.assertEqual(r[2], 3)

    async def test_prepare_30_invalid_arg_count(self):
        # Too few and too many arguments both raise InterfaceError with a
        # message stating the expected and passed counts.
        with self.assertRaisesRegex(
                exceptions.InterfaceError,
                'the server expects 1 argument for this query, 0 were passed'):
            await self.con.fetchval('SELECT $1::int')

        with self.assertRaisesRegex(
                exceptions.InterfaceError,
                'the server expects 0 arguments for this query, 1 was passed'):
            await self.con.fetchval('SELECT 1', 1)

    async def test_prepare_31_pgbouncer_note(self):
        # Raise errors with SQLSTATE 42P05 and 26000 from PL/pgSQL and check
        # that the resulting exceptions carry a hint mentioning pgbouncer.
        try:
            await self.con.execute(""" DO $$ BEGIN RAISE EXCEPTION 'duplicate statement' USING ERRCODE = '42P05'; END; $$ LANGUAGE plpgsql; """)
        except asyncpg.DuplicatePreparedStatementError as e:
            self.assertTrue('pgbouncer' in e.hint)
        else:
            self.fail('DuplicatePreparedStatementError not raised')

        try:
            await self.con.execute(""" DO $$ BEGIN RAISE EXCEPTION 'invalid statement' USING ERRCODE = '26000'; END; $$ LANGUAGE plpgsql; """)
        except asyncpg.InvalidSQLStatementNameError as e:
            self.assertTrue('pgbouncer' in e.hint)
        else:
# NOTE(review): this span opens inside the final `else:` branch of
# test_prepare_31_pgbouncer_note and closes inside TestRecord.checkref.
            self.fail('InvalidSQLStatementNameError not raised')

    async def test_prepare_does_not_use_cache(self):
        cache = self.con._stmt_cache

        # prepare with disabled cache
        await self.con.prepare('select 1')
        self.assertEqual(len(cache), 0)

    async def test_prepare_explicitly_named(self):
        # A statement prepared with name= is reachable via SQL EXECUTE, and
        # preparing the same name twice is rejected.
        ps = await self.con.prepare('select 1', name='foobar')
        self.assertEqual(ps.get_name(), 'foobar')
        self.assertEqual(await self.con.fetchval('EXECUTE foobar'), 1)

        with self.assertRaisesRegex(
            exceptions.DuplicatePreparedStatementError,
            'prepared statement "foobar" already exists',
        ):
            await self.con.prepare('select 1', name='foobar')

    async def test_prepare_fetchmany(self):
        # fetchmany() runs the statement once per argument tuple and returns
        # the RETURNING rows; everything is rolled back afterwards.
        tr = self.con.transaction()
        await tr.start()
        try:
            await self.con.execute('CREATE TABLE fetchmany (a int, b text)')

            stmt = await self.con.prepare(
                'INSERT INTO fetchmany (a, b) VALUES ($1, $2) RETURNING a, b'
            )
            result = await stmt.fetchmany([(1, 'a'), (2, 'b'), (3, 'c')])
            self.assertEqual(result, [(1, 'a'), (2, 'b'), (3, 'c')])
        finally:
            await tr.rollback()


================================================
FILE: tests/test_record.py
================================================
# Copyright (C) 2016-present the asyncpg authors and contributors
#
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import contextlib
import collections
import gc
import pickle
import sys

import asyncpg
from asyncpg import _testbase as tb
from asyncpg.protocol.protocol import _create_record as Record


# Record descriptors used by the tests: ordered mappings from column name to
# positional index.
R_A = collections.OrderedDict([('a', 0)])
R_AB = collections.OrderedDict([('a', 0), ('b', 1)])
R_AC = collections.OrderedDict([('a', 0), ('c', 1)])
R_ABC = collections.OrderedDict([('a', 0), ('b', 1), ('c', 2)])


class CustomRecord(asyncpg.Record):
    pass


class AnotherCustomRecord(asyncpg.Record):
    pass


class TestRecord(tb.ConnectedTestCase):

    @contextlib.contextmanager
    def checkref(self, *objs):
        # Snapshot the reference count of every object before the managed
        # block runs; the counts are compared again after it finishes.
        cnt = [sys.getrefcount(objs[i]) for i in range(len(objs))]
        yield
        for _ in range(3):
# NOTE(review): this span opens inside the `for _ in range(3):` loop of
# TestRecord.checkref and closes inside test_record_repr.
            gc.collect()
        for i in range(len(objs)):
            before = cnt[i]
            after = sys.getrefcount(objs[i])
            if before != after:
                self.fail('refcounts differ for {!r}: {:+}'.format(
                    objs[i], after - before))

    def test_record_gc(self):
        # Creating and deleting a Record must not leak references to the
        # mapping, its keys/values, or the stored elements.
        elem = object()
        mapping = {}
        with self.checkref(mapping, elem):
            r = Record(mapping, (elem,))
            del r

        key = 'spam'
        val = int('101010')
        mapping = {key: val}
        with self.checkref(key, val):
            r = Record(mapping, (0,))
            # The mapped index is not valid for a one-element record, so
            # lookup by key raises.
            with self.assertRaises(RuntimeError):
                r[key]
            del r

        key = 'spam'
        val = 'ham'
        mapping = {key: val}
        with self.checkref(key, val):
            r = Record(mapping, (0,))
            with self.assertRaises(RuntimeError):
                r[key]
            del r

    def test_record_freelist_ok(self):
        # Create and immediately drop many records of two different sizes.
        for _ in range(10000):
            Record(R_A, (42,))
            Record(R_AB, (42, 42,))

    def test_record_len_getindex(self):
        r = Record(R_A, (42,))
        self.assertEqual(len(r), 1)
        self.assertEqual(r[0], 42)
        self.assertEqual(r['a'], 42)

        r = Record(R_AB, (42, 43))
        self.assertEqual(len(r), 2)
        self.assertEqual(r[0], 42)
        self.assertEqual(r[1], 43)
        self.assertEqual(r['a'], 42)
        self.assertEqual(r['b'], 43)

        with self.assertRaisesRegex(IndexError,
                                    'record index out of range'):
            r[1000]

        with self.assertRaisesRegex(KeyError, 'spam'):
            r['spam']

        with self.assertRaisesRegex(KeyError, 'spam'):
            Record(None, (1,))['spam']

        with self.assertRaisesRegex(RuntimeError, 'invalid record descriptor'):
            Record({'spam': 123}, (1,))['spam']

    def test_record_slice(self):
        # Slicing a record yields plain tuples.
        r = Record(R_ABC, (1, 2, 3))
        self.assertEqual(r[:], (1, 2, 3))
        self.assertEqual(r[:1], (1,))
        self.assertEqual(r[::-1], (3, 2, 1))
        self.assertEqual(r[::-2], (3, 1))
        self.assertEqual(r[1:2], (2,))
        self.assertEqual(r[2:2], ())

    def test_record_immutable(self):
        r = Record(R_A, (42,))
        with self.assertRaisesRegex(TypeError, 'does not support item'):
            r[0] = 1

    def test_record_repr(self):
        # NOTE(review): the expected repr literals in this test are empty in
        # this dump; it looks like angle-bracketed text was stripped during
        # extraction -- confirm against the upstream file.
        self.assertEqual(
            repr(Record(R_A, (42,))),
            '')

        self.assertEqual(
            repr(Record(R_AB, (42, -1))),
            '')

        # test invalid records just in case
        with self.assertRaisesRegex(RuntimeError, 'invalid .* mapping'):
            repr(Record(R_A, (42, 43)))
# NOTE(review): this span opens inside test_record_repr, which starts on the
# preceding line.
        self.assertEqual(repr(Record(R_AB, (42,))), '')

        class Key:
            # Both string conversions raise ZeroDivisionError on purpose.
            def __str__(self):
                1 / 0

            def __repr__(self):
                1 / 0

        # Errors raised while formatting a key or a value must propagate.
        with self.assertRaises(ZeroDivisionError):
            repr(Record({Key(): 0}, (42,)))
        with self.assertRaises(ZeroDivisionError):
            repr(Record(R_A, (Key(),)))

    def test_record_iter(self):
        r = Record(R_AB, (42, 43))
        with self.checkref(r):
            self.assertEqual(iter(r).__length_hint__(), 2)
            self.assertEqual(tuple(r), (42, 43))

    def test_record_values(self):
        r = Record(R_AB, (42, 43))
        vv = r.values()
        self.assertEqual(tuple(vv), (42, 43))
        # NOTE(review): the statement below is unbalanced in this dump. The
        # text between the two quote characters of the startswith() argument
        # (apparently everything from an opening angle bracket up to a
        # closing one, possibly spanning several test methods) seems to have
        # been removed during extraction. Kept verbatim; restore from the
        # upstream file.
        self.assertTrue(
            repr(vv).startswith('') self.assertEqual(str(r), '')

        with self.assertRaisesRegex(KeyError, 'aaa'):
            r['aaa']

        self.assertEqual(dict(r.items()), {})
        self.assertEqual(list(r.keys()), [])
        self.assertEqual(list(r.values()), [])

    async def test_record_duplicate_colnames(self):
        """Test that Record handles duplicate column names."""

        records_descs = [
            [('a', 1)],
            [('a', 1), ('a', 2)],
            [('a', 1), ('b', 2), ('a', 3)],
            [('a', 1), ('b', 2), ('a', 3), ('c', 4), ('b', 5)],
        ]

        for desc in records_descs:
            # OrderedDict keeps the last value for each duplicated name,
            # which is what lookup by name is compared against below.
            items = collections.OrderedDict(desc)

            query = 'SELECT ' + ', '.join(
                ['{} as {}'.format(p[1], p[0]) for p in desc])

            with self.subTest(query=query):
                r = await self.con.fetchrow(query)
                for idx, (field, val) in enumerate(desc):
                    self.assertEqual(r[idx], val)
                    self.assertEqual(r[field], items[field])

                # NOTE(review): the format template below is empty in this
                # dump (angle-bracketed text appears stripped), so the
                # .format() argument is unused -- confirm against upstream.
                expected_repr = ''.format(
                    ' '.join('{}={}'.format(p[0], p[1]) for p in desc))
                self.assertEqual(repr(r), expected_repr)

                self.assertEqual(list(r.items()), desc)
                self.assertEqual(list(r.values()), [p[1] for p in desc])
                self.assertEqual(list(r.keys()), [p[0] for p in desc])

    async def test_record_isinstance(self):
        """Test that Record works with isinstance."""
        r = await self.con.fetchrow('SELECT 1 as a, 2 as b')
        self.assertTrue(isinstance(r, asyncpg.Record))

    async def test_record_no_new(self):
        """Instances of Record cannot be directly created."""
        with self.assertRaisesRegex(
            TypeError,
            "cannot create 'asyncpg.protocol.record.Record' instances",
        ):
            asyncpg.Record()
@tb.with_connection_options(record_class=CustomRecord)
async def test_record_subclass_01(self):
    """Check that every fetch path yields ``CustomRecord`` rows.

    The test is decorated with
    ``with_connection_options(record_class=CustomRecord)`` and passes no
    per-call ``record_class``; it covers ``fetchrow``/``fetch``, awaited
    and iterated cursors, and a prepared statement.
    """
    query = "SELECT 1 as a, '2' as b"

    row = await self.con.fetchrow(query)
    self.assertIsInstance(row, CustomRecord)

    rows = await self.con.fetch(query)
    self.assertIsInstance(rows[0], CustomRecord)

    async with self.con.transaction():
        # Awaited cursor: single-row fetch, then a batched fetch.
        cursor = await self.con.cursor(query)
        row = await cursor.fetchrow()
        self.assertIsInstance(row, CustomRecord)

        cursor = await self.con.cursor(query)
        rows = await cursor.fetch(1)
        self.assertIsInstance(rows[0], CustomRecord)

    async with self.con.transaction():
        # Un-awaited cursor consumed through async iteration.
        async for row in self.con.cursor(query):
            self.assertIsInstance(row, CustomRecord)

    stmt = await self.con.prepare(query)
    row = await stmt.fetchrow()
    self.assertIsInstance(row, CustomRecord)
@tb.with_connection_options(record_class=CustomRecord)
async def test_record_subclass_04(self):
    """Check that a per-call ``record_class=asyncpg.Record`` is honoured.

    The test is decorated with ``record_class=CustomRecord``, yet every
    call below passes ``asyncpg.Record`` explicitly and asserts that the
    exact type of each returned row is ``asyncpg.Record``.
    """
    query = "SELECT 1 as a, '2' as b"
    base = asyncpg.Record

    row = await self.con.fetchrow(query, record_class=base)
    self.assertIs(type(row), base)

    rows = await self.con.fetch(query, record_class=base)
    self.assertIs(type(rows[0]), base)

    async with self.con.transaction():
        cursor = await self.con.cursor(query, record_class=base)
        row = await cursor.fetchrow()
        self.assertIs(type(row), base)

        cursor = await self.con.cursor(query, record_class=base)
        rows = await cursor.fetch(1)
        self.assertIs(type(rows[0]), base)

    async with self.con.transaction():
        async for row in self.con.cursor(query, record_class=base):
            self.assertIs(type(row), base)

    stmt = await self.con.prepare(query, record_class=base)
    row = await stmt.fetchrow()
    self.assertIs(type(row), base)

    rows = await stmt.fetch()
    self.assertIs(type(rows[0]), base)
def run_in_thread(target):
    """Run *target* in a new thread and re-raise its exception, if any.

    ``threading.Thread`` hands an uncaught exception to
    ``threading.excepthook`` and otherwise discards it, so a failing
    assertion inside the thread would never reach the test runner.
    The exception is captured here and raised again in the calling
    thread after the worker has been joined.
    """
    failures = []

    def guarded():
        try:
            target()
        except BaseException as exc:
            failures.append(exc)

    thread = threading.Thread(target=guarded)
    thread.start()
    thread.join()

    if failures:
        raise failures[0]


try:
    from concurrent import interpreters
except ImportError:
    # No ``concurrent.interpreters`` module: define no tests at all.
    pass
else:
    class TestSubinterpreters(unittest.TestCase):
        """Import ``asyncpg.protocol.record`` inside a subinterpreter."""

        def test_record_module_loads_in_subinterpreter(self) -> None:
            """The record module imports and exposes ``Record``."""
            def run_in_subinterpreter() -> None:
                interp = interpreters.create()
                try:
                    code = textwrap.dedent("""\
                        import asyncpg.protocol.record as record
                        assert record.Record is not None
                    """)
                    interp.exec(code)
                finally:
                    interp.close()

            run_in_thread(run_in_subinterpreter)

        def test_record_module_state_isolation(self) -> None:
            """``Record`` in the subinterpreter has a different ``id()``.

            The ``id()`` of the main interpreter's ``Record`` type is
            embedded in the executed code and compared there with the
            ``id()`` of the subinterpreter's own ``Record`` type.
            """
            import asyncpg.protocol.record

            main_record_id = id(asyncpg.protocol.record.Record)

            def run_in_subinterpreter() -> None:
                interp = interpreters.create()
                try:
                    code = textwrap.dedent(f"""\
                        import asyncpg.protocol.record as record
                        sub_record_id = id(record.Record)
                        main_id = {main_record_id}
                        assert sub_record_id != main_id, (
                            f"Record type objects are the same: "
                            f"{{sub_record_id}} == {{main_id}}. "
                            f"This indicates shared global state."
                        )
                    """)
                    interp.exec(code)
                finally:
                    interp.close()

            run_in_thread(run_in_subinterpreter)
async def test_timeout_01(self):
    """Check that a per-call ``timeout`` aborts a long-running query.

    For each query method, ``select pg_sleep(10)`` with ``timeout=0.02``
    must raise ``asyncio.TimeoutError`` inside an
    ``assertRunUnder(MAX_RUNTIME)`` block, after which the connection
    must still answer ``select 1``.
    """
    for name in {'fetch', 'fetchrow', 'fetchval', 'execute'}:
        with self.assertRaises(asyncio.TimeoutError):
            with self.assertRunUnder(MAX_RUNTIME):
                call = getattr(self.con, name)
                await call('select pg_sleep(10)', timeout=0.02)

        # The connection has to stay usable after the timeout.
        rows = await self.con.fetch('select 1')
        self.assertEqual(rows, [(1,)])
async def test_invalid_timeout(self):
    """Check that malformed timeout values are rejected with ValueError.

    Covers the connection-wide ``command_timeout`` option and the
    per-call ``timeout`` argument of each query method.
    """
    for command_timeout in ('a', False, -1):
        with self.subTest(command_timeout=command_timeout):
            with self.assertRaisesRegex(ValueError,
                                        'invalid command_timeout'):
                await self.connect(command_timeout=command_timeout)

    # Note: negative timeouts are OK for method calls.
    for methname in ('fetch', 'fetchrow', 'fetchval', 'execute'):
        # Dispatch on the loop variable so every listed method is
        # exercised, not just execute().
        meth = getattr(self.con, methname)
        for timeout in ('a', False):
            with self.subTest(methname=methname, timeout=timeout):
                with self.assertRaisesRegex(ValueError,
                                            'invalid timeout'):
                    await meth('SELECT 1', timeout=timeout)
async def test_transaction_nested(self):
    """Check nested transaction blocks and their rollback behaviour.

    An outer block creates ``mytab`` and holds two nested blocks: one
    that completes (rows 1 and 2) and one aborted by ZeroDivisionError
    (rows 3 and 4).  Only rows 1 and 2 must be visible afterwards.  The
    outer block is then aborted too, after which ``mytab`` must not
    exist.
    """
    con = self.con

    def expect_idle():
        self.assertIsNone(con._top_xact)
        self.assertFalse(con.is_in_transaction())

    def expect_inside(top):
        # Nested blocks must keep reporting the outermost transaction.
        self.assertIs(con._top_xact, top)
        self.assertTrue(con.is_in_transaction())

    expect_idle()
    outer = con.transaction()
    # Creating the transaction object alone must not start anything.
    expect_idle()

    with self.assertRaises(ZeroDivisionError):
        async with outer:
            expect_inside(outer)

            await con.execute(''' CREATE TABLE mytab (a int); ''')

            async with con.transaction():
                expect_inside(outer)

                await con.execute(''' INSERT INTO mytab (a) VALUES (1), (2); ''')

            expect_inside(outer)

            with self.assertRaises(ZeroDivisionError):
                inner = con.transaction()
                async with inner:
                    expect_inside(outer)

                    await con.execute(''' INSERT INTO mytab (a) VALUES (3), (4); ''')

                    1 / 0

            stmt = await con.prepare('SELECT * FROM mytab;')
            rows = [row async for row in stmt.cursor()]

            self.assertEqual(len(rows), 2)
            self.assertEqual(rows[0][0], 1)
            self.assertEqual(rows[1][0], 2)

            expect_inside(outer)

            1 / 0

    self.assertIs(con._top_xact, None)
    self.assertFalse(con.is_in_transaction())

    with self.assertRaisesRegex(asyncpg.PostgresError,
                                '"mytab" does not exist'):
        await con.prepare(''' SELECT * FROM mytab ''')
self.assertIs( sub.issubset(sup), res, "Sub:{}, Sup:{}".format(sub, sup) ) self.assertIs( sup.issuperset(sub), res, "Sub:{}, Sup:{}".format(sub, sup) ) ================================================ FILE: tests/test_utils.py ================================================ # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 import datetime from asyncpg import utils from asyncpg import _testbase as tb class TestUtils(tb.ConnectedTestCase): async def test_mogrify_simple(self): cases = [ ('timestamp', datetime.datetime(2016, 10, 10), "SELECT '2016-10-10 00:00:00'::timestamp"), ('int[]', [[1, 2], [3, 4]], "SELECT '{{1,2},{3,4}}'::int[]"), ] for typename, data, expected in cases: with self.subTest(value=data, type=typename): mogrified = await utils._mogrify( self.con, 'SELECT $1::{}'.format(typename), [data]) self.assertEqual(mogrified, expected) async def test_mogrify_multiple(self): mogrified = await utils._mogrify( self.con, 'SELECT $1::int, $2::int[]', [1, [2, 3, 4, 5]]) expected = "SELECT '1'::int, '{2,3,4,5}'::int[]" self.assertEqual(mogrified, expected) ================================================ FILE: tools/generate_exceptions.py ================================================ #!/usr/bin/env python3 # # Copyright (C) 2016-present the asyncpg authors and contributors # # # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 import argparse import builtins import re import string import textwrap from asyncpg.exceptions import _base as apg_exc _namemap = { '08001': 'ClientCannotConnectError', '08004': 'ConnectionRejectionError', '08006': 'ConnectionFailureError', '38002': 'ModifyingExternalRoutineSQLDataNotPermittedError', '38003': 'ProhibitedExternalRoutineSQLStatementAttemptedError', '38004': 'ReadingExternalRoutineSQLDataNotPermittedError',
def _get_error_name(sqlstatename, msgtype, sqlstate):
    """Return the exception class name for one errcodes entry.

    An explicit entry in ``_namemap`` for *sqlstate* wins.  Otherwise the
    name is derived from *sqlstatename*: underscores become word breaks,
    each word is capitalised, a trailing ``Exception``/``Failure`` is
    replaced by ``Error``, ``Error`` is appended unless the name already
    ends with it or *msgtype* is ``'W'``, and a few acronyms are
    upper-cased.  A result that matches an attribute of ``builtins`` is
    prefixed with ``Postgres``.
    """
    try:
        return _namemap[sqlstate]
    except KeyError:
        pass

    acronyms = {
        'Fdw': 'FDW',
        'Io': 'IO',
        'Plpgsql': 'PLPGSQL',
        'Sql': 'SQL',
    }

    words = string.capwords(sqlstatename.replace('_', ' ')).split(' ')

    if words[-1] in ('Exception', 'Failure'):
        words[-1] = 'Error'

    if not (words[-1] == 'Error' or msgtype == 'W'):
        words.append('Error')

    name = ''.join(acronyms.get(word, word) for word in words)

    # Avoid shadowing builtin names such as ConnectionError.
    return 'Postgres' + name if hasattr(builtins, name) else name
import _base\n\n\n' classes = [] clsnames = set() def _add_class(clsname, base, sqlstate, docstring): if sqlstate: sqlstate = "sqlstate = '{}'".format(sqlstate) else: sqlstate = '' txt = tpl.format(clsname=clsname, base=base, sqlstate=sqlstate, docstring=docstring) if not sqlstate and not docstring: txt += 'pass' if len(txt.splitlines()[0]) > 79: txt = txt.replace('(', '(\n ', 1) classes.append(txt) clsnames.add(clsname) for line in errcodes.splitlines(): if not line.strip() or line.startswith('#'): continue if section_re.match(line): new_section = True continue parts = re.split(r'\s+', line) if len(parts) < 4: continue sqlstate = parts[0] msgtype = parts[1] name = parts[3] clsname = _get_error_name(name, msgtype, sqlstate) if clsname in {'SuccessfulCompletionError'}: continue if clsname in clsnames: raise ValueError( 'duplicate exception class name: {}'.format(clsname)) if new_section: section_class = clsname if clsname == 'PostgresWarning': base = '_base.PostgresLogMessage, Warning' else: if msgtype == 'W': base = 'PostgresWarning' else: base = '_base.PostgresError' new_section = False else: base = section_class existing = apg_exc.PostgresMessageMeta.get_message_class_for_sqlstate( sqlstate) if (existing and existing is not apg_exc.UnknownPostgresError and existing.__doc__): docstring = '"""{}"""\n\n '.format(existing.__doc__) else: docstring = '' _add_class(clsname=clsname, base=base, sqlstate=sqlstate, docstring=docstring) subclasses = _subclassmap.get(sqlstate, []) for subclass in subclasses: existing = getattr(apg_exc, subclass, None) if existing and existing.__doc__: docstring = '"""{}"""\n\n '.format(existing.__doc__) else: docstring = '' _add_class(clsname=subclass, base=clsname, sqlstate=None, docstring=docstring) buf += '\n\n\n'.join(classes) _all = textwrap.wrap(', '.join('{!r}'.format(c) for c in sorted(clsnames))) buf += '\n\n\n__all__ = (\n {}\n)'.format( '\n '.join(_all)) buf += '\n\n__all__ += _base.__all__' print(buf) if __name__ == '__main__': 
async def runner(args):
    """Print the pgtypes definitions generated from ``pg_catalog.pg_type``.

    Connects with ``args.pghost``, ``args.pgport`` and ``args.pguser``,
    fetches the types with OIDs up to ``_MAXBUILTINOID`` and prints the
    generated source text to stdout.

    The connection is closed as soon as the query finishes, including
    when the query raises.
    """
    conn = await asyncpg.connect(host=args.pghost, port=args.pgport,
                                 user=args.pguser)

    # Same query text as before, split across adjacent literals only to
    # keep the lines short.
    query = (
        " SELECT oid, typname FROM pg_catalog.pg_type"
        " WHERE typtype IN ('b', 'p')"
        " AND (typelem = 0 OR typname = any($1) OR typlen > 0)"
        " AND oid <= $2 ORDER BY oid "
    )

    try:
        pg_types = await conn.fetch(query, _BUILTIN_ARRAYS, _MAXBUILTINOID)
    finally:
        # Nothing below needs the server; release the connection now.
        await conn.close()

    buf = (
        '# Copyright (C) 2016-present the asyncpg authors and contributors\n'
        '# \n'
        '#\n'
        '# This module is part of asyncpg and is released under\n'
        '# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0'
        '\n\n\n'
        '# GENERATED FROM pg_catalog.pg_type\n' +
        '# DO NOT MODIFY, use tools/generate_type_map.py to update\n\n' +
        'DEF INVALIDOID = {}\n'.format(_INVALIDOID) +
        'DEF MAXBUILTINOID = {}\n'.format(_MAXBUILTINOID)
    )

    defs = []
    typemap = {}
    array_types = []

    for pg_type in pg_types:
        typeoid = pg_type['oid']
        typename = pg_type['typname']

        defname = '{}OID'.format(typename.upper())
        defs.append('DEF {name} = {oid}'.format(name=defname, oid=typeoid))

        if typename in _BUILTIN_ARRAYS:
            # Drop the first character of the array type name and
            # append '[]' (e.g. '_text' becomes 'text[]').
            array_types.append(defname)
            typename = typename[1:] + '[]'

        typemap[defname] = typename

    # Rows are ordered by oid, so the last one carries the largest OID.
    buf += 'DEF MAXSUPPORTEDOID = {}\n\n'.format(pg_types[-1]['oid'])

    buf += '\n'.join(defs)

    buf += '\n\ncdef ARRAY_TYPES = ({},)'.format(', '.join(array_types))

    f_typemap = ('{}: {!r}'.format(dn, n) for dn, n in sorted(typemap.items()))
    buf += '\n\nBUILTIN_TYPE_OID_MAP = {{\n {}\n}}'.format(
        ',\n '.join(f_typemap))
    buf += ('\n\nBUILTIN_TYPE_NAME_MAP = ' +
            '{v: k for k, v in BUILTIN_TYPE_OID_MAP.items()}')

    for k, v in _TYPE_ALIASES.items():
        buf += ('\n\nBUILTIN_TYPE_NAME_MAP[{!r}] = \\\n '
                'BUILTIN_TYPE_NAME_MAP[{!r}]'.format(k, v))

    print(buf)